From a1d765e2a70c0fe756db3943fc357a40e3b948c4 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Mon, 2 Feb 2026 17:01:50 -0800 Subject: [PATCH 1/4] Remove jaxtyping dependency Summary: Gets rid of the jaxtyping dependency as proposed in #113 (due to limited value and causing dependency issues in some places). Test Plan: unit tests --- .conda/meta.yaml | 1 - .github/workflows/deploy.yml | 2 +- .github/workflows/pull_request.yml | 5 - .github/workflows/push_to_main.yml | 2 +- .../workflows/run_type_checked_test_suite.yml | 36 -- .gitignore | 3 + .hooks/check_type_hints.sh | 9 - .hooks/propagate_type_hints.py | 121 ------- .hooks/propagate_type_hints.sh | 7 - .pre-commit-config.yaml | 8 - CONTRIBUTING.md | 38 +- docs/source/conf.py | 28 -- linear_operator/operators/_linear_operator.py | 335 ++++++++++-------- .../operators/added_diag_linear_operator.py | 27 +- .../operators/batch_repeat_linear_operator.py | 53 +-- .../operators/block_diag_linear_operator.py | 61 ++-- .../block_interleaved_linear_operator.py | 43 +-- .../operators/block_linear_operator.py | 23 +- .../operators/cat_linear_operator.py | 37 +- .../operators/chol_linear_operator.py | 61 ++-- .../operators/constant_mul_linear_operator.py | 37 +- .../operators/dense_linear_operator.py | 41 ++- .../operators/diag_linear_operator.py | 175 +++++---- .../operators/identity_linear_operator.py | 113 +++--- .../operators/interpolated_linear_operator.py | 49 +-- .../operators/keops_linear_operator.py | 15 +- .../operators/kernel_linear_operator.py | 25 +- ...cker_product_added_diag_linear_operator.py | 43 +-- .../kronecker_product_linear_operator.py | 153 ++++---- ...ow_rank_root_added_diag_linear_operator.py | 39 +- .../low_rank_root_linear_operator.py | 13 +- .../operators/masked_linear_operator.py | 46 ++- .../operators/matmul_linear_operator.py | 29 +- .../operators/mul_linear_operator.py | 27 +- .../operators/permutation_linear_operator.py | 37 +- .../operators/psd_sum_linear_operator.py | 5 +- .../operators/root_linear_operator.py | 47 +-- .../operators/sum_batch_linear_operator.py | 9 +- .../sum_kronecker_linear_operator.py | 33 +- .../operators/sum_linear_operator.py | 39 +- .../operators/toeplitz_linear_operator.py | 29 +- .../operators/triangular_linear_operator.py | 93 ++--- .../operators/zero_linear_operator.py | 87 ++--- linear_operator/settings.py | 1 + .../test/type_checking_test_case.py | 97 ----- setup.py | 5 +- 46 files changed, 1023 insertions(+), 1164 deletions(-) delete mode 100644 .github/workflows/run_type_checked_test_suite.yml delete mode 100644 .hooks/check_type_hints.sh delete mode 100644 .hooks/propagate_type_hints.py delete mode 100755 .hooks/propagate_type_hints.sh delete mode 100644 linear_operator/test/type_checking_test_case.py diff --git a/.conda/meta.yaml b/.conda/meta.yaml index df19fbc4..3438d202 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -21,7 +21,6 @@ requirements: run: - python>=3.10 - - jaxtyping - mpmath>=0.19,<=1.3 - pytorch>=2.0 - scipy diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 7ba89566..a5434e2d 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -12,7 +12,7 @@ jobs: uses: ./.github/workflows/run_linter.yml run_test_suite: - uses: ./.github/workflows/run_type_checked_test_suite.yml + uses: ./.github/workflows/run_test_suite.yml deploy_pypi: runs-on: ubuntu-latest diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 593484ce..8c230b28 100644 --- a/.github/workflows/pull_request.yml 
+++ b/.github/workflows/pull_request.yml @@ -13,8 +13,3 @@ jobs: run_test_suite: uses: ./.github/workflows/run_test_suite.yml - - run_small_type_checked_test_suite: - uses: ./.github/workflows/run_type_checked_test_suite.yml - with: - files_to_test: "test/operators/test_dense_linear_operator.py test/operators/test_diag_linear_operator.py test/operators/test_kronecker_product_linear_operator.py" diff --git a/.github/workflows/push_to_main.yml b/.github/workflows/push_to_main.yml index 820297bd..efbbfc83 100644 --- a/.github/workflows/push_to_main.yml +++ b/.github/workflows/push_to_main.yml @@ -12,4 +12,4 @@ jobs: uses: ./.github/workflows/run_linter.yml run_test_suite: - uses: ./.github/workflows/run_type_checked_test_suite.yml + uses: ./.github/workflows/run_test_suite.yml diff --git a/.github/workflows/run_type_checked_test_suite.yml b/.github/workflows/run_type_checked_test_suite.yml deleted file mode 100644 index 1ba35b97..00000000 --- a/.github/workflows/run_type_checked_test_suite.yml +++ /dev/null @@ -1,36 +0,0 @@ -# This workflow will install Python dependencies, run tests and lint with a variety of Python versions -# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions - -name: Run Type Checked Test Suite - -on: - workflow_call: - inputs: - files_to_test: - required: false - type: string - -jobs: - run_type_checked_unit_tests: - runs-on: ubuntu-latest - strategy: - matrix: - pytorch-version: ["latest", "stable"] - steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: "3.10" - - name: Install dependencies - run: | - if [[ ${{ matrix.pytorch-version }} = "latest" ]]; then - pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - else - pip install "numpy<2" # Numpy 2.0 is not fully supported until PyTorch 2.2 - pip install torch==2.0+cpu -f https://download.pytorch.org/whl/torch_stable.html - fi - pip install -e ".[test]" - - name: Run unit tests - run: | - pytest ${{ inputs.files_to_test }} --jaxtyping-packages=typeguard.typechecked diff --git a/.gitignore b/.gitignore index 14f691dd..2717ddf4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Claude +CLAUDE.md + # Atom plugin files and ctags .ftpconfig .ftpconfig.cson diff --git a/.hooks/check_type_hints.sh b/.hooks/check_type_hints.sh deleted file mode 100644 index d4783a9f..00000000 --- a/.hooks/check_type_hints.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/usr/bin/env bash -echo "HI" 1>&2 -CHANGED_FILES=$(git diff --cached --name-only | grep linear_operator/operators) -echo $CHANGED_FILES 1>&2 -if [[ -n "$CHANGED_FILES" ]]; then - python ./.hooks/propagate_type_hints.py -else - echo "NO CHANGED FILES" 1>&2 -fi diff --git a/.hooks/propagate_type_hints.py b/.hooks/propagate_type_hints.py deleted file mode 100644 index cde4bc83..00000000 --- a/.hooks/propagate_type_hints.py +++ /dev/null @@ -1,121 +0,0 @@ -# Propagate type hints & signatures defined in _linear_operator.py to derived classes. -# Here we leverage libcst which can preserve the original whitespace & formatting of the file -# The idea is that we only want to change the type hints. -# This way we can enforce consistency between the base class signature and derived signatures. - -import os -from pathlib import Path -from typing import List, Optional, Tuple, TypedDict - -import libcst as cst - - -class Annotations(TypedDict): - key: Tuple[str, ...] 
# key: tuple of canonical class/function name - value: Tuple[cst.Parameters, Optional[cst.Annotation]] # value: (params, returns) - - -class TypingCollector(cst.CSTVisitor): - def __init__(self) -> None: - # stack for storing the canonical name of the current function - self.stack: List[Tuple[str, ...]] = [] - # store the annotations - self.annotations: Annotations = {} - - def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: - self.stack.append(node.name.value) - - def leave_ClassDef(self, node: cst.ClassDef) -> None: - self.stack.pop() - - def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: - self.stack.append(node.name.value) - self.annotations[tuple(self.stack)] = (node.params, node.returns) - return False # pyi files don't support inner functions, return False to stop the traversal. - - def leave_FunctionDef(self, node: cst.FunctionDef) -> None: - self.stack.pop() - - -class TypingTransformer(cst.CSTTransformer): - - # List of LinearOperator functions we do not want to propagate the signature from - excluded_functions = ["__init__", "_check_args", "__torch_function__"] - - def __init__(self, annotations: Annotations): - # stack for storing the canonical name of the current function - self.stack: List[Tuple[str, ...]] = [] - # store the annotations - self.annotations: Annotations = annotations - - def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: - self.stack.append(node.name.value) - - def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode: - self.stack.pop() - return updated_node - - def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: - self.stack.append(node.name.value) - return False # pyi files don't support inner functions, return False to stop the traversal. 
- - def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode: - key = tuple(self.stack) - if key[-1] in TypingTransformer.excluded_functions: - return updated_node - try: - if original_node.params.params[0].name.value != "self": # Assume this is not a class method - return updated_node - except Exception: - return updated_node - key = ("LinearOperator", key[-1]) - self.stack.pop() - if key in self.annotations: - annotations = self.annotations[key] - return updated_node.with_changes(params=annotations[0], returns=annotations[1]) - return updated_node - - -def collect_base_type_hints(base_filename: Path) -> TypingCollector: - base_tree = cst.parse_module(base_filename.read_text()) - base_visitor = TypingCollector() - base_tree.visit(base_visitor) - return base_visitor - - -def copy_base_type_hints_to_derived(target: Path, base_visitor: TypingCollector) -> cst.Module: - source_tree = cst.parse_module(target.read_text()) - transformer = TypingTransformer(base_visitor.annotations) - modified_tree = source_tree.visit(transformer) - return modified_tree - - -def main(): - directory = "linear_operator/operators" - base_filename = Path(directory) / "_linear_operator.py" - base_visitor = collect_base_type_hints(base_filename) - - os.environ["TYPE_HINTS_PROPAGATED"] = "0" - changed_files = [] - - pathlist = Path(directory).glob("*.py") - for path in pathlist: - if path.name[0] == "_": - continue - target = path - target_out = path - original_code = target.read_text() - modified_code = copy_base_type_hints_to_derived(target, base_visitor).code - if original_code != modified_code: - changed_files.append(path) - with open(target_out, "w") as f: - f.write(modified_code) - - if len(changed_files): - print("The following files have been changed:") # noqa T201 - for changed_file in changed_files: - print(f" - {changed_file}") # noqa T201 - os.environ["TYPE_HINTS_PROPAGATED"] = "1" - - -main() diff --git a/.hooks/propagate_type_hints.sh b/.hooks/propagate_type_hints.sh deleted file mode 100755 index 5693407e..00000000 --- a/.hooks/propagate_type_hints.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env bash -if [[ -n "$(echo $@ | grep linear_operator/operators)" ]]; then - python ./.hooks/propagate_type_hints.py - if [[ $TYPE_HINTS_PROPAGATED = 1 ]]; then - exit 2 - fi -fi diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ce1b3635..1f15949b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,11 +38,3 @@ repos: hooks: - id: forbid-crlf - id: forbid-tabs -- repo: local - hooks: - - id: propagate-type-hints - name: Propagate Type Hints - entry: ./.hooks/propagate_type_hints.sh - language: script - pass_filenames: true - require_serial: true diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index fe36716a..c1ffb573 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -39,38 +39,20 @@ We use [standard sphinx docstrings](https://sphinx-rtd-tutorial.readthedocs.io/e LinearOperator aims to be fully typed using Python 3.10+ [type hints](https://www.python.org/dev/peps/pep-0484/). We expect any contributions to also use proper type annotations. -We are using [jaxtyping](https://github.com/google/jaxtyping) to help us be declarative about the dimension sizes used -in the LinearOperator methods. -The use of [jaxtyping](https://github.com/google/jaxtyping) makes it clearer what the functions are doing algebraically -and where broadcasting is happening. 
- -These type hints are checked in the unit tests by using -[typeguard](https://github.com/agronholm/typeguard) to perform run-time -checking of the signatures to make sure they are accurate. -The signatures are written into the base linear operator class in `_linear_oparator.py`. -These signatures are then copied to the derived classes by running the script -`python ./.hooks/propagate_type_hints.py`. -This is done for: -1. Consistency. Make sure the derived implementations are following the promised interface. -2. Visibility. Make it easy to see what the expected signature is, along with dimensions. Repeating the signature in the derived classes enhances readability. -3. Necessity. The way that jaxtyping and typeguard are written, they won't run type checks unless type annotations are present in the derived method signature. - -In short, if you want to update the type hints, update the code in the LinearOperator class in -`_linear_oparator.py` then run `python ./.hooks/propagate_type_hints.py` to copy the new signature to the derived -classes. - -### Unit Tests -#### With type checking (slower, but more thorough) -To run the unittests with type checking, run -```bash -pytest --jaxtyping-packages=linear_operator,typeguard.typechecked +Dimension information is documented using inline comments with the format: +```python +def _matmul( + self: LinearOperator, # shape: (*batch, M, N) + rhs: Tensor, # shape: (*batch2, N, C) or (*batch2, N) +) -> Tensor: # shape: (..., M, C) or (..., M) ``` -- To run tests within a specific directory, run (e.g.) `pytest test/operators --jaxtyping-packages=linear_operator,typeguard.typechecked`. -- To run a specific file, run (e.g.) `pytest test/operators/test_matmul_linear_operator.py --jaxtyping-packages=linear_operator,typeguard.typechecked`. +This convention makes it clear what the functions are doing algebraically +and where broadcasting is happening. + +### Unit Tests -#### Without type checking (faster, but less thorough) We use python's `unittest` to run unit tests: ```bash python -m unittest diff --git a/docs/source/conf.py b/docs/source/conf.py index ff6f39cb..3241eacb 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,7 +19,6 @@ import warnings from typing import ForwardRef -import jaxtyping import sphinx_rtd_theme # noqa @@ -123,27 +122,6 @@ def _convert_internal_and_external_class_to_strings(annotation): return res -# Convert jaxtyping dimensions into strings -def _dim_to_str(dim): - if isinstance(dim, jaxtyping._array_types._NamedVariadicDim): - return "..." - elif isinstance(dim, jaxtyping._array_types._FixedDim): - res = str(dim.size) - if dim.broadcastable: - res = "#" + res - return res - elif isinstance(dim, jaxtyping._array_types._SymbolicDim): - expr = dim.elem - return f"({expr})" - elif "jaxtyping" not in str(dim.__class__): # Probably the case that we have an ellipsis - return "..." - else: - res = str(dim.name) - if dim.broadcastable: - res = "#" + res - return res - - # Function to format type hints def _process(annotation, config): """ @@ -156,12 +134,6 @@ def _process(annotation, config): if type(annotation) == str: return annotation - # Jaxtyping: shaped tensors or linear operator - elif hasattr(annotation, "__module__") and "jaxtyping" == annotation.__module__: - cls_annotation = _convert_internal_and_external_class_to_strings(annotation.array_type) - shape = " x ".join([_dim_to_str(dim) for dim in annotation.dims]) - return f"{cls_annotation} ({shape})" - # Convert Ellipsis into "..." 
elif annotation == Ellipsis: return "..." diff --git a/linear_operator/operators/_linear_operator.py b/linear_operator/operators/_linear_operator.py index f89e806d..3e301032 100644 --- a/linear_operator/operators/_linear_operator.py +++ b/linear_operator/operators/_linear_operator.py @@ -15,11 +15,6 @@ import numpy as np import torch -try: - # optional library for advanced type signatures - from jaxtyping import Float, Int -except ImportError: - pass from torch import Tensor import linear_operator @@ -173,9 +168,9 @@ def __init__(self, *args, **kwargs): @abstractmethod def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) r""" Performs a matrix multiplication :math:`\mathbf KM` with the (... x M x N) matrix :math:`\mathbf K` that this LinearOperator represents. Should behave as @@ -209,7 +204,9 @@ def _size(self) -> torch.Size: raise NotImplementedError("The class {} requires a _size function!".format(self.__class__.__name__)) @abstractmethod - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) """ Transposes non-batch dimensions (e.g. last two) Implement this method, rather than transpose() or t(). @@ -396,8 +393,8 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O return tuple(grads) def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) """ Expands along batch dimensions. Return size will be *batch_shape x *matrix_shape. @@ -483,7 +480,9 @@ def _args(self, args: Tuple[Union[torch.Tensor, "LinearOperator", int], ...]) -> def _kwargs(self) -> Dict[str, Any]: return {**self._differentiable_kwargs, **self._nondifferentiable_kwargs} - def _approx_diagonal(self: Float[LinearOperator, "*batch N N"]) -> Float[torch.Tensor, "*batch N"]: + def _approx_diagonal( + self: LinearOperator, # shape: (*batch, N, N) + ) -> torch.Tensor: # shape: (*batch, N) """ (Optional) returns an (approximate) diagonal of the matrix @@ -499,8 +498,8 @@ def _approx_diagonal(self: Float[LinearOperator, "*batch N N"]) -> Float[torch.T @cached(name="cholesky") def _cholesky( - self: Float[LinearOperator, "*batch N N"], upper: Optional[bool] = False - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) """ (Optional) Cholesky-factorizes the LinearOperator @@ -529,10 +528,10 @@ def _cholesky( return TriangularLinearOperator(cholesky, upper=upper) def _cholesky_solve( - self: Float[LinearOperator, "*batch N N"], - rhs: Union[Float[LinearOperator, "*batch2 N M"], Float[Tensor, "*batch2 N M"]], + self: LinearOperator, # shape: (*batch, N, N) + rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) upper: Optional[bool] = False, - ) -> Union[Float[LinearOperator, "... N M"], Float[Tensor, "... 
N M"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) """ (Optional) Assuming that `self` is a Cholesky factor, computes the cholesky solve. @@ -561,7 +560,9 @@ def _choose_root_method(self) -> str: return "cholesky" return "lanczos" - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) r""" As :func:`torch._diagonal`, returns the diagonal of the matrix :math:`\mathbf A` this LinearOperator represents as a vector. @@ -575,8 +576,8 @@ def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, ".. return self[..., row_col_iter, row_col_iter] def _mul_constant( - self: Float[LinearOperator, "*batch M N"], other: Union[float, torch.Tensor] - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ Multiplies the LinearOperator by a constant. @@ -592,9 +593,9 @@ def _mul_constant( return ConstantMulLinearOperator(self, other) def _mul_matrix( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[torch.Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"]], - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[torch.Tensor, LinearOperator], # shape: (..., #M, #N) + ) -> LinearOperator: # shape: (..., M, N) r""" Multiplies the LinearOperator by a (batch of) matrices. @@ -686,8 +687,8 @@ def _prod_batch(self, dim: int) -> LinearOperator: return res def _root_decomposition( - self: Float[LinearOperator, "... N N"] - ) -> Union[Float[torch.Tensor, "... N N"], Float[LinearOperator, "... N N"]]: + self: LinearOperator, # shape: (..., N, N) + ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) """ Returns the (usually low-rank) root of a LinearOperator of a PSD matrix. @@ -720,10 +721,10 @@ def _root_decomposition_size(self) -> int: return settings.max_root_decomposition_size.value() def _root_inv_decomposition( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) initial_vectors: Optional[torch.Tensor] = None, test_vectors: Optional[torch.Tensor] = None, - ) -> Union[Float[LinearOperator, "... N N"], Float[Tensor, "... N N"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) r""" Returns the (usually low-rank) inverse root of a LinearOperator of a PSD matrix. @@ -778,15 +779,15 @@ def _set_requires_grad(self, val: bool) -> None: arg.requires_grad_(val) def _solve( - self: Float[LinearOperator, "... N N"], - rhs: Float[torch.Tensor, "... N C"], - preconditioner: Optional[Callable[[Float[torch.Tensor, "... N C"]], Float[torch.Tensor, "... N C"]]] = None, + self: LinearOperator, # shape: (..., N, N) + rhs: torch.Tensor, # shape: (..., N, C) + preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) num_tridiag: Optional[int] = 0, ) -> Union[ - Float[torch.Tensor, "... N C"], + torch.Tensor, # shape: (..., N, C) Tuple[ - Float[torch.Tensor, "... N C"], - Float[torch.Tensor, "..."], # Note that in case of a tuple the second term size depends on num_tridiag + torch.Tensor, # shape: (..., N, C) + torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) 
], ]: r""" @@ -862,8 +863,8 @@ def _sum_batch(self, dim: int) -> LinearOperator: @cached(name="svd") def _svd( - self: Float[LinearOperator, "*batch N N"] - ) -> Tuple[Float[LinearOperator, "*batch N N"], Float[Tensor, "... N"], Float[LinearOperator, "*batch N N"]]: + self: LinearOperator, # shape: (*batch, N, N) + ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) """Method that allows implementing special-cased SVD computation. Should not be called directly""" # Using symeig is preferable here for psd LinearOperators. # Will need to overwrite this function for non-psd LinearOperators. @@ -875,10 +876,10 @@ def _svd( return U, S, V def _symeig( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Float[Tensor, "*batch M"], Optional[Float[LinearOperator, "*batch N M"]]]: + ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) r""" Method that allows implementing special-cased symeig computation. Should not be called directly """ @@ -900,9 +901,9 @@ def _symeig( return evals, evecs def _t_matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]], - ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) r""" Performs a transpose matrix multiplication :math:`\mathbf K^\top \mathbf M` with the (... x M x N) matrix :math:`\mathbf K` that this LinearOperator represents. @@ -927,10 +928,10 @@ def abs(self) -> LinearOperator: @_implements_symmetric(torch.add) def add( - self: Float[LinearOperator, "*batch M N"], - other: Union[Float[Tensor, "*batch M N"], Float[LinearOperator, "*batch M N"]], + self: LinearOperator, # shape: (*batch, M, N) + other: Union[Tensor, LinearOperator], # shape: (*batch, M, N) alpha: float = None, - ) -> Float[LinearOperator, "*batch M N"]: + ) -> LinearOperator: # shape: (*batch, M, N) r""" Each element of the tensor :attr:`other` is multiplied by the scalar :attr:`alpha` and added to each element of the :obj:`~linear_operator.operators.LinearOperator`. @@ -950,9 +951,9 @@ def add( return self + alpha * other def add_diagonal( - self: Float[LinearOperator, "*batch N N"], - diag: Union[Float[torch.Tensor, "... N"], Float[torch.Tensor, "... 1"], Float[torch.Tensor, ""]], - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, # shape: (*batch, N, N) + diag: torch.Tensor, # shape: (..., N) or (..., 1) or () + ) -> LinearOperator: # shape: (*batch, N, N) r""" Adds an element to the diagonal of the matrix. @@ -1000,8 +1001,8 @@ def add_diagonal( return AddedDiagLinearOperator(self, diag_tensor) def add_jitter( - self: Float[LinearOperator, "*batch N N"], jitter_val: float = 1e-3 - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, jitter_val: float = 1e-3 # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) r""" Adds jitter (i.e., a small diagonal component) to the matrix this LinearOperator represents. @@ -1016,13 +1017,13 @@ def add_jitter( return self.add_diagonal(diag) def add_low_rank( - self: Float[LinearOperator, "*batch N N"], - low_rank_mat: Union[Float[Tensor, "... N _"], Float[LinearOperator, "... 
N _"]], + self: LinearOperator, # shape: (*batch, N, N) + low_rank_mat: Union[Tensor, LinearOperator], # shape: (..., N, _) root_decomp_method: Optional[str] = None, root_inv_decomp_method: Optional[str] = None, generate_roots: Optional[bool] = True, **root_decomp_kwargs, - ) -> Float[LinearOperator, "*batch N N"]: # returns SumLinearOperator + ) -> LinearOperator: # returns SumLinearOperator # shape: (*batch, N, N) r""" Adds a low rank matrix to the matrix that this LinearOperator represents, e.g. computes :math:`\mathbf A + \mathbf{BB}^\top`. @@ -1148,13 +1149,13 @@ def batch_shape(self) -> torch.Size: return self.shape[:-2] def cat_rows( - self: Float[LinearOperator, "... M N"], - cross_mat: Float[torch.Tensor, "... O N"], - new_mat: Float[torch.Tensor, "... O O"], + self: LinearOperator, # shape: (..., M, N) + cross_mat: torch.Tensor, # shape: (..., O, N) + new_mat: torch.Tensor, # shape: (..., O, O) generate_roots: bool = True, generate_inv_roots: bool = True, **root_decomp_kwargs, - ) -> Float[LinearOperator, "... M+O N+O"]: + ) -> LinearOperator: # shape: (..., M+O, N+O) r""" Concatenates new rows and columns to the matrix that this LinearOperator represents, e.g. @@ -1300,8 +1301,8 @@ def cat_rows( @_implements(torch.linalg.cholesky) def cholesky( - self: Float[LinearOperator, "*batch N N"], upper: bool = False - ) -> Float[LinearOperator, "*batch N N"]: # returns TriangularLinearOperator + self: LinearOperator, upper: bool = False # shape: (*batch, N, N) + ) -> LinearOperator: # returns TriangularLinearOperator # shape: (*batch, N, N) """ Cholesky-factorizes the LinearOperator. @@ -1314,7 +1315,9 @@ def cholesky( return chol @_implements(torch.clone) - def clone(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def clone( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ Returns clone of the LinearOperator (with clones of all underlying tensors) """ @@ -1322,7 +1325,9 @@ def clone(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "* kwargs = {key: val.clone() if hasattr(val, "clone") else val for key, val in self._kwargs.items()} return self.__class__(*args, **kwargs) - def cpu(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def cpu( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ Returns new LinearOperator identical to :attr:`self`, but on the CPU. """ @@ -1341,8 +1346,8 @@ def cpu(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*ba return self.__class__(*new_args, **new_kwargs) def cuda( - self: Float[LinearOperator, "*batch M N"], device_id: Optional[str] = None - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, device_id: Optional[str] = None # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ This method operates identically to :func:`torch.nn.Module.cuda`. @@ -1366,7 +1371,9 @@ def cuda( def device(self) -> Optional[torch.device]: return self._args[0].device - def detach(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def detach( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ Removes the LinearOperator from the current computation graph. 
(In practice, this function removes all Tensors that make up the @@ -1378,7 +1385,9 @@ def detach(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, " ) return self.__class__(*detached_args, **detached_kwargs) - def detach_(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def detach_( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ An in-place version of :meth:`detach`. """ @@ -1392,8 +1401,8 @@ def detach_(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, @_implements(torch.diagonal) def diagonal( - self: Float[LinearOperator, "*batch N N"], offset: int = 0, dim1: int = -2, dim2: int = -1 - ) -> Float[Tensor, "*batch N"]: + self: LinearOperator, offset: int = 0, dim1: int = -2, dim2: int = -1 # shape: (*batch, N, N) + ) -> Tensor: # shape: (*batch, N) r""" As :func:`torch.diagonal`, returns the diagonal of the matrix :math:`\mathbf A` this LinearOperator represents as a vector. @@ -1420,8 +1429,8 @@ def diagonal( @cached(name="diagonalization") def diagonalization( - self: Float[LinearOperator, "*batch N N"], method: Optional[str] = None - ) -> Tuple[Float[Tensor, "*batch M"], Optional[Float[LinearOperator, "*batch N M"]]]: + self: LinearOperator, method: Optional[str] = None # shape: (*batch, N, N) + ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) """ Returns a (usually partial) diagonalization of a symmetric PSD matrix. Options are either "lanczos" or "symeig". "lanczos" runs Lanczos while @@ -1487,8 +1496,8 @@ def div(self, other: Union[float, torch.Tensor]) -> LinearOperator: return self.mul(1.0 / other) def double( - self: Float[LinearOperator, "*batch M N"], device_id: Optional[str] = None - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, device_id: Optional[str] = None # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ This method operates identically to :func:`torch.Tensor.double`. @@ -1502,8 +1511,8 @@ def dtype(self) -> Optional[torch.dtype]: @_implements(torch.linalg.eigh) def eigh( - self: Float[LinearOperator, "*batch N N"] - ) -> Tuple[Float[Tensor, "*batch N"], Optional[Float[LinearOperator, "*batch N N"]]]: + self: LinearOperator, # shape: (*batch, N, N) + ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, N), (*batch, N, N) """ Compute the symmetric eigendecomposition of the linear operator. This can be very slow for large tensors. @@ -1525,10 +1534,8 @@ def eigh( @_implements(torch.linalg.eigvalsh) def eigvalsh( - self: Float[LinearOperator, "*batch N N"] - ) -> Union[ - Float[Tensor, "*batch N"], Tuple[Float[Tensor, "*batch N"], Optional[Float[LinearOperator, "*batch N N"]]] - ]: + self: LinearOperator, # shape: (*batch, N, N) + ) -> Union[Tensor, Tuple[Tensor, Optional[LinearOperator]]]: # shape: (*batch, N) or (*batch, N, N) """ Compute the eigenvalues of symmetric linear operator. This can be very slow for large tensors. 
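For reference, a minimal sketch of how the shape comments on `diagonal` and `eigvalsh` above read in practice (an illustrative aside, not part of the patched files; it assumes a dense PSD operator wrapped with `to_linear_operator`):

```python
import torch

from linear_operator.operators.dense_linear_operator import to_linear_operator

# Batch of 3 PSD matrices, shape: (*batch, N, N) = (3, 5, 5)
A = torch.randn(3, 5, 5)
op = to_linear_operator(A @ A.mT + torch.eye(5))

# diagonal() returns the matrix diagonals, shape: (*batch, N) = (3, 5);
# per the @_implements decorator above it is also reachable as torch.diagonal(op).
print(op.diagonal().shape)

# eigvalsh() returns the eigenvalues of the symmetric operator, shape: (*batch, N) = (3, 5)
print(op.eigvalsh().shape)
```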
@@ -1555,7 +1562,9 @@ def evaluate_kernel(self): return self.representation_tree()(*self.representation()) @_implements(torch.exp) - def exp(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def exp( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) # Only implemented by some LinearOperator subclasses # We define it here so that we can map the torch function torch.exp to the LinearOperator method raise NotImplementedError(f"torch.exp({self.__class__.__name__}) is not implemented.") @@ -1598,8 +1607,8 @@ def expand(self, *sizes: Union[torch.Size, int]) -> LinearOperator: return res def float( - self: Float[LinearOperator, "*batch M N"], device_id: Optional[str] = None - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, device_id: Optional[str] = None # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ This method operates identically to :func:`torch.Tensor.float`. @@ -1608,8 +1617,8 @@ def float( return self.type(torch.float) def half( - self: Float[LinearOperator, "*batch M N"], device_id: Optional[str] = None - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, device_id: Optional[str] = None # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ This method operates identically to :func:`torch.Tensor.half`. @@ -1618,10 +1627,10 @@ def half( return self.type(torch.half) def inv_quad( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]], + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Tensor, # shape: (*batch, N, M) or (*batch, N) reduce_inv_quad: bool = True, - ) -> Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"]]: + ) -> Tensor: # shape: (*batch, M) or (*batch) r""" Computes an inverse quadratic form (w.r.t self) with several right hand sides, i.e: @@ -1669,13 +1678,13 @@ def inv_quad( return inv_quad_term def inv_quad_logdet( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Optional[Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]]] = None, + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, ) -> Tuple[ - Optional[Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"], Float[Tensor, " 0"]]], - Optional[Float[Tensor, "..."]], + Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) + Optional[Tensor], # shape: (...) 
]: r""" Calls both :func:`inv_quad` and :func:`logdet` on a positive @@ -1787,7 +1796,9 @@ def inv_quad_logdet( return inv_quad_term, logdet_term @_implements(torch.inverse) - def inverse(self: Float[LinearOperator, "*batch N N"]) -> Float[LinearOperator, "*batch N N"]: + def inverse( + self: LinearOperator, # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) # Only implemented by some LinearOperator subclasses # We define it here so that we can map the torch function torch.inverse to the LinearOperator method raise NotImplementedError(f"torch.inverse({self.__class__.__name__}) is not implemented.") @@ -1801,13 +1812,17 @@ def isclose(self, other, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bo return self._isclose(other, rtol=rtol, atol=atol, equal_nan=equal_nan) @_implements(torch.log) - def log(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def log( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) # Only implemented by some LinearOperator subclasses # We define it here so that we can map the torch function torch.log to the LinearOperator method raise NotImplementedError(f"torch.log({self.__class__.__name__}) is not implemented.") @_implements(torch.logdet) - def logdet(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, " *batch"]: + def logdet( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch) r""" Computes the log determinant :math:`\log \vert \mathbf A \vert`. """ @@ -1816,9 +1831,9 @@ def logdet(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, " *batch" @_implements(torch.matmul) def matmul( - self: Float[LinearOperator, "*batch M N"], - other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]], - ) -> Union[Float[Tensor, "... M P"], Float[Tensor, "... M"], Float[LinearOperator, "... M P"]]: + self: LinearOperator, # shape: (*batch, M, N) + other: Union[Tensor, LinearOperator], # shape: (*batch2, N, P) or (*batch2, N) + ) -> Union[Tensor, LinearOperator]: # shape: (..., M, P) or (..., M) r""" Performs :math:`\mathbf A \mathbf B`, where :math:`\mathbf A \in \mathbb R^{M \times N}` is the LinearOperator and :math:`\mathbf B` @@ -1843,7 +1858,9 @@ def matrix_shape(self) -> torch.Size: return torch.Size(self.shape[-2:]) @property - def mT(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def mT( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) """ Alias of transpose(-1, -2) """ @@ -1851,9 +1868,9 @@ def mT(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*bat @_implements_symmetric(torch.mul) def mul( - self: Float[LinearOperator, "*batch M N"], - other: Union[float, Float[Tensor, "*batch2 M N"], Float[LinearOperator, "*batch2 M N"]], - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, # shape: (*batch, M, N) + other: Union[float, Tensor, LinearOperator], # shape: (*batch2, M, N) + ) -> LinearOperator: # shape: (..., M, N) """ Multiplies the matrix by a constant, or elementwise the matrix by another matrix. 
@@ -1944,11 +1961,11 @@ def permute(self, *dims: Union[int, Tuple[int, ...]]) -> LinearOperator: return self._permute_batch(*dims[:-2]) def pivoted_cholesky( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) rank: int, error_tol: Optional[float] = None, return_pivots: bool = False, - ) -> Union[Float[Tensor, "*batch N R"], Tuple[Float[Tensor, "*batch N R"], Int[Tensor, "*batch N"]]]: + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: # shape: (*batch, N, R), (*batch, N, R), (*batch, N) r""" Performs a partial pivoted Cholesky factorization of the (positive definite) LinearOperator. :math:`\mathbf L \mathbf L^\top = \mathbf K`. @@ -2108,9 +2125,9 @@ def reshape(self, *sizes: Union[torch.Size, int, Tuple[int, ...]]) -> LinearOper @_implements_second_arg(torch.matmul) def rmatmul( - self: Float[LinearOperator, "... M N"], - other: Union[Float[Tensor, "... P M"], Float[Tensor, "... M"], Float[LinearOperator, "... P M"]], - ) -> Union[Float[Tensor, "... P N"], Float[Tensor, "N"], Float[LinearOperator, "... P N"]]: + self: LinearOperator, # shape: (..., M, N) + other: Union[Tensor, LinearOperator], # shape: (..., P, M) or (..., M) + ) -> Union[Tensor, LinearOperator]: # shape: (..., P, N) or (N) r""" Performs :math:`\mathbf B \mathbf A`, where :math:`\mathbf A \in \mathbb R^{M \times N}` is the LinearOperator and :math:`\mathbf B` @@ -2127,8 +2144,8 @@ def rmatmul( @cached(name="root_decomposition") def root_decomposition( - self: Float[LinearOperator, "*batch N N"], method: Optional[str] = None - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, method: Optional[str] = None # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) r""" Returns a (usually low-rank) root decomposition linear operator of the PSD LinearOperator :math:`\mathbf A`. This can be used for sampling from a Gaussian distribution, or for obtaining a @@ -2190,11 +2207,11 @@ def root_decomposition( @cached(name="root_inv_decomposition") def root_inv_decomposition( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) initial_vectors: Optional[torch.Tensor] = None, test_vectors: Optional[torch.Tensor] = None, method: Optional[str] = None, - ) -> Union[Float[LinearOperator, "... N N"], Float[Tensor, "... N N"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) r""" Returns a (usually low-rank) inverse root decomposition linear operator of the PSD LinearOperator :math:`\mathbf A`. @@ -2294,10 +2311,10 @@ def shape(self) -> torch.Size: @_implements(torch.linalg.solve) def solve( - self: Float[LinearOperator, "... N N"], - right_tensor: Union[Float[Tensor, "... N P"], Float[Tensor, " N"]], - left_tensor: Optional[Float[Tensor, "... O N"]] = None, - ) -> Union[Float[Tensor, "... N P"], Float[Tensor, "... N"], Float[Tensor, "... O P"], Float[Tensor, "... O"]]: + self: LinearOperator, # shape: (..., N, N) + right_tensor: Tensor, # shape: (..., N, P) or (N) + left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) r""" Computes a linear solve (w.r.t self = :math:`\mathbf A`) with right hand side :math:`\mathbf R`. I.e. 
computes @@ -2383,16 +2400,18 @@ def solve_triangular( raise NotImplementedError(f"torch.linalg.solve_triangular({self.__class__.__name__}) is not implemented.") @_implements(torch.sqrt) - def sqrt(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def sqrt( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) # Only implemented by some LinearOperator subclasses # We define it here so that we can map the torch function torch.sqrt to the LinearOperator method raise NotImplementedError(f"torch.sqrt({self.__class__.__name__}) is not implemented.") def sqrt_inv_matmul( - self: Float[LinearOperator, "*batch N N"], - rhs: Float[Tensor, "*batch N P"], - lhs: Optional[Float[Tensor, "*batch O N"]] = None, - ) -> Union[Float[Tensor, "*batch N P"], Tuple[Float[Tensor, "*batch O P"], Float[Tensor, "*batch O"]]]: + self: LinearOperator, # shape: (*batch, N, N) + rhs: Tensor, # shape: (*batch, N, P) + lhs: Optional[Tensor] = None, # shape: (*batch, O, N) + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: # shape: (*batch, N, P), (*batch, O, P), (*batch, O) r""" If the LinearOperator :math:`\mathbf A` is positive definite, computes @@ -2453,10 +2472,10 @@ def squeeze(self, dim: int) -> Union[LinearOperator, torch.Tensor]: @_implements(torch.sub) def sub( - self: Float[LinearOperator, "*batch M N"], - other: Union[Float[Tensor, "*batch M N"], Float[LinearOperator, "*batch M N"]], + self: LinearOperator, # shape: (*batch, M, N) + other: Union[Tensor, LinearOperator], # shape: (*batch, M, N) alpha: float = None, - ) -> Float[LinearOperator, "*batch M N"]: + ) -> LinearOperator: # shape: (*batch, M, N) r""" Each element of the tensor :attr:`other` is multiplied by the scalar :attr:`alpha` and subtracted to each element of the :obj:`~linear_operator.operators.LinearOperator`. @@ -2519,8 +2538,8 @@ def sum(self, dim: Optional[int] = None) -> Union[LinearOperator, torch.Tensor]: raise ValueError("Invalid dim ({}) for LinearOperator of size {}".format(orig_dim, self.shape)) def svd( - self: Float[LinearOperator, "*batch N N"] - ) -> Tuple[Float[LinearOperator, "*batch N N"], Float[Tensor, "... N"], Float[LinearOperator, "*batch N N"]]: + self: LinearOperator, # shape: (*batch, N, N) + ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) r""" Compute the SVD of the linear operator :math:`\mathbf A \in \mathbb R^{M \times N}` s.t. :math:`\mathbf A = \mathbf{U S V^\top}`. @@ -2539,8 +2558,8 @@ def svd( @_implements(torch.linalg.svd) def _torch_linalg_svd( - self: Float[LinearOperator, "*batch N N"] - ) -> Tuple[Float[LinearOperator, "*batch N N"], Float[Tensor, "... N"], Float[LinearOperator, "*batch N N"]]: + self: LinearOperator, # shape: (*batch, N, N) + ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) r""" A version of self.svd() that matches the torch.linalg.svd API. 
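To make the `solve` and `svd` shape comments concrete, a small sketch (an illustrative aside, not part of the patched files; it assumes a well-conditioned dense PSD operator):

```python
import torch

from linear_operator.operators.dense_linear_operator import to_linear_operator

A = torch.randn(5, 5)
op = to_linear_operator(A @ A.mT + 5 * torch.eye(5))  # shape: (N, N) = (5, 5)

rhs = torch.randn(5, 2)  # shape: (..., N, P)
x = op.solve(rhs)        # shape: (..., N, P) = (5, 2)
print(x.shape)

# svd() returns U, S, V with shapes (N, N), (N,), (N, N);
# torch.linalg.svd(op) returns U, S, V.mT as in _torch_linalg_svd above.
U, S, V = op.svd()
print(U.shape, S.shape, V.shape)
```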
@@ -2553,13 +2572,17 @@ def _torch_linalg_svd( return U, S, V.mT @property - def T(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def T( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) """ Alias of t() """ return self.t() - def t(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def t( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) """ Alias of :meth:`~linear_operator.LinearOperator.transpose` for 2D LinearOperator. (Tranposes the two dimensions.) @@ -2568,7 +2591,11 @@ def t(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batc raise RuntimeError("Cannot call t for more than 2 dimensions") return self.transpose(0, 1) - def to(self: Float[LinearOperator, "*batch M N"], *args, **kwargs) -> Float[LinearOperator, "*batch M N"]: + def to( + self: LinearOperator, # shape: (*batch, M, N) + *args, + **kwargs, + ) -> LinearOperator: # shape: (*batch, M, N) """ A device-agnostic method of moving the LinearOperator to the specified device or dtype. This method functions just like :meth:`torch.Tensor.to`. @@ -2592,7 +2619,9 @@ def to(self: Float[LinearOperator, "*batch M N"], *args, **kwargs) -> Float[Line return self.__class__(*new_args, **new_kwargs) @cached - def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + def to_dense( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch, M, N) """ Explicitly evaluates the matrix this LinearOperator represents. This function should return a :obj:`torch.Tensor` storing an exact representation of this LinearOperator. @@ -2703,8 +2732,8 @@ def unsqueeze(self, dim: int) -> LinearOperator: # TODO: replace this method with something like sqrt_matmul. def zero_mean_mvn_samples( - self: Float[LinearOperator, "*batch N N"], num_samples: int - ) -> Float[Tensor, "num_samples *batch N"]: + self: LinearOperator, num_samples: int # shape: (*batch, N, N) + ) -> Tensor: # shape: (num_samples, *batch, N) r""" Assumes that the LinearOpeator :math:`\mathbf A` is a covariance matrix, or a batch of covariance matrices. @@ -2752,15 +2781,15 @@ def zero_mean_mvn_samples( return samples def __sub__( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], - ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) + ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) return self + other.mul(-1) def __add__( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], - ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... 
M N"]]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) + ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator from linear_operator.operators.dense_linear_operator import to_linear_operator from linear_operator.operators.diag_linear_operator import DiagLinearOperator @@ -2897,46 +2926,44 @@ def _isclose(self, other, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: b return torch.isclose(to_dense(self), to_dense(other), rtol=rtol, atol=atol, equal_nan=equal_nan) def __matmul__( - self: Float[LinearOperator, "*batch M N"], - other: Union[ - Float[torch.Tensor, "*batch2 N D"], Float[torch.Tensor, "N"], Float[LinearOperator, "*batch2 N D"] - ], - ) -> Union[Float[torch.Tensor, "... M D"], Float[torch.Tensor, "... M"], Float[LinearOperator, "... M D"]]: + self: LinearOperator, # shape: (*batch, M, N) + other: Union[torch.Tensor, LinearOperator], # shape: (*batch2, N, D) or (N) + ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., M, D) or (..., M) return self.matmul(other) @_implements_second_arg(torch.Tensor.matmul) def __rmatmul__( - self: Float[LinearOperator, "... M N"], - other: Union[Float[Tensor, "... P M"], Float[Tensor, "... M"], Float[LinearOperator, "... P M"]], - ) -> Union[Float[Tensor, "... P N"], Float[Tensor, "... N"], Float[LinearOperator, "... P N"]]: + self: LinearOperator, # shape: (..., M, N) + other: Union[Tensor, LinearOperator], # shape: (..., P, M) or (..., M) + ) -> Union[Tensor, LinearOperator]: # shape: (..., P, N) or (..., N) return self.rmatmul(other) @_implements_second_arg(torch.Tensor.mul) def __mul__( - self: Float[LinearOperator, "*batch #M #N"], - other: Union[Float[torch.Tensor, "*batch2 #M #N"], Float[LinearOperator, "*batch2 #M #N"], float], - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, # shape: (*batch, #M, #N) + other: Union[torch.Tensor, LinearOperator, float], # shape: (*batch2, #M, #N) + ) -> LinearOperator: # shape: (..., M, N) return self.mul(other) @_implements_second_arg(torch.Tensor.add) def __radd__( - self: Float[LinearOperator, "*batch #M #N"], - other: Union[Float[torch.Tensor, "*batch2 #M #N"], Float[LinearOperator, "*batch2 #M #N"], float], - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, # shape: (*batch, #M, #N) + other: Union[torch.Tensor, LinearOperator, float], # shape: (*batch2, #M, #N) + ) -> LinearOperator: # shape: (..., M, N) return self + other def __rmul__( - self: Float[LinearOperator, "*batch #M #N"], - other: Union[Float[torch.Tensor, "*batch2 #M #N"], Float[LinearOperator, "*batch2 #M #N"], float], - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, # shape: (*batch, #M, #N) + other: Union[torch.Tensor, LinearOperator, float], # shape: (*batch2, #M, #N) + ) -> LinearOperator: # shape: (..., M, N) return self.mul(other) @_implements_second_arg(torch.sub) @_implements_second_arg(torch.Tensor.sub) def __rsub__( - self: Float[LinearOperator, "*batch #M #N"], - other: Union[Float[torch.Tensor, "*batch2 #M #N"], Float[LinearOperator, "*batch2 #M #N"], float], - ) -> Float[LinearOperator, "... 
M N"]: + self: LinearOperator, # shape: (*batch, #M, #N) + other: Union[torch.Tensor, LinearOperator, float], # shape: (*batch2, #M, #N) + ) -> LinearOperator: # shape: (..., M, N) return self.mul(-1) + other @classmethod diff --git a/linear_operator/operators/added_diag_linear_operator.py b/linear_operator/operators/added_diag_linear_operator.py index 71a82e46..a75342da 100644 --- a/linear_operator/operators/added_diag_linear_operator.py +++ b/linear_operator/operators/added_diag_linear_operator.py @@ -6,7 +6,6 @@ from typing import Callable, List, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator import settings @@ -71,21 +70,21 @@ def __init__( self._r_cache = None def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) return torch.addcmul(self._linear_op._matmul(rhs), self._diag_tensor._diag.unsqueeze(-1), rhs) def add_diagonal( - self: Float[LinearOperator, "*batch N N"], - diag: Union[Float[torch.Tensor, "... N"], Float[torch.Tensor, "... 1"], Float[torch.Tensor, ""]], - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, # shape: (*batch, N, N) + diag: torch.Tensor, # shape: (..., N) or (..., 1) or () + ) -> LinearOperator: # shape: (*batch, N, N) return self.__class__(self._linear_op, self._diag_tensor.add_diagonal(diag)) def __add__( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], - ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) + ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) from linear_operator.operators.diag_linear_operator import DiagLinearOperator if isinstance(other, DiagLinearOperator): @@ -186,8 +185,8 @@ def _init_cache_for_non_constant_diag(self, eye: Tensor, batch_shape: Union[torc @cached(name="svd") def _svd( - self: Float[LinearOperator, "*batch N N"] - ) -> Tuple[Float[LinearOperator, "*batch N N"], Float[Tensor, "... 
N"], Float[LinearOperator, "*batch N N"]]: + self: LinearOperator, # shape: (*batch, N, N) + ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) if isinstance(self._diag_tensor, ConstantDiagLinearOperator): U, S_, V = self._linear_op.svd() S = S_ + self._diag_tensor._diagonal() @@ -195,10 +194,10 @@ def _svd( return super()._svd() def _symeig( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Float[Tensor, "*batch M"], Optional[Float[LinearOperator, "*batch N M"]]]: + ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) if isinstance(self._diag_tensor, ConstantDiagLinearOperator): evals_, evecs = self._linear_op._symeig(eigenvectors=eigenvectors) evals = evals_ + self._diag_tensor._diagonal() diff --git a/linear_operator/operators/batch_repeat_linear_operator.py b/linear_operator/operators/batch_repeat_linear_operator.py index 04ee478e..7188d32b 100644 --- a/linear_operator/operators/batch_repeat_linear_operator.py +++ b/linear_operator/operators/batch_repeat_linear_operator.py @@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator import settings @@ -40,8 +39,8 @@ def __init__(self, base_linear_op, batch_repeat=torch.Size((1,))): @cached(name="cholesky") def _cholesky( - self: Float[LinearOperator, "*batch N N"], upper: Optional[bool] = False - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator res = self.base_linear_op.cholesky(upper=upper)._tensor @@ -49,10 +48,10 @@ def _cholesky( return TriangularLinearOperator(res, upper=upper) def _cholesky_solve( - self: Float[LinearOperator, "*batch N N"], - rhs: Union[Float[LinearOperator, "*batch2 N M"], Float[Tensor, "*batch2 N M"]], + self: LinearOperator, # shape: (*batch, N, N) + rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) upper: Optional[bool] = False, - ) -> Union[Float[LinearOperator, "... N M"], Float[Tensor, "... N M"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) # TODO: Figure out how to deal with this with TriangularLinearOperator if returned by _cholesky output_shape = _matmul_broadcast_shape(self.shape, rhs.shape) if rhs.shape != output_shape: @@ -73,8 +72,8 @@ def _compute_batch_repeat_size( return batch_repeat def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) padding_dims = torch.Size(tuple(1 for _ in range(max(len(batch_shape) + 2 - self.base_linear_op.dim(), 0)))) current_batch_shape = padding_dims + self.base_linear_op.batch_shape return self.__class__( @@ -111,9 +110,9 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I return new_linear_op._getitem(row_index, col_index, *batch_indices) def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... 
M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) output_shape = _matmul_broadcast_shape(self.shape, rhs.shape) # only attempt broadcasting if the non-batch dimensions are the same @@ -216,15 +215,15 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O return super()._bilinear_derivative(left_vecs, right_vecs) def _root_decomposition( - self: Float[LinearOperator, "... N N"] - ) -> Union[Float[torch.Tensor, "... N N"], Float[LinearOperator, "... N N"]]: + self: LinearOperator, # shape: (..., N, N) + ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) return self.base_linear_op._root_decomposition().repeat(*self.batch_repeat, 1, 1) def _root_inv_decomposition( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) initial_vectors: Optional[torch.Tensor] = None, test_vectors: Optional[torch.Tensor] = None, - ) -> Union[Float[LinearOperator, "... N N"], Float[Tensor, "... N N"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) return self.base_linear_op._root_inv_decomposition().repeat(*self.batch_repeat, 1, 1) def _size(self) -> torch.Size: @@ -234,7 +233,9 @@ def _size(self) -> torch.Size: res = torch.Size(repeated_batch_shape + self.base_linear_op.matrix_shape) return res - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return self.__class__(self.base_linear_op._transpose_nonbatch(), batch_repeat=self.batch_repeat) def _unsqueeze_batch(self, dim: int) -> LinearOperator: @@ -250,18 +251,18 @@ def _unsqueeze_batch(self, dim: int) -> LinearOperator: return self.__class__(base_linear_op, batch_repeat=batch_repeat) def add_jitter( - self: Float[LinearOperator, "*batch N N"], jitter_val: float = 1e-3 - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, jitter_val: float = 1e-3 # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) return self.__class__(self.base_linear_op.add_jitter(jitter_val=jitter_val), batch_repeat=self.batch_repeat) def inv_quad_logdet( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Optional[Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]]] = None, + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, ) -> Tuple[ - Optional[Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"], Float[Tensor, " 0"]]], - Optional[Float[Tensor, "..."]], + Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) + Optional[Tensor], # shape: (...) ]: if not self.is_square: raise RuntimeError( @@ -319,8 +320,8 @@ def repeat(self, *sizes: Union[int, Tuple[int, ...]]) -> LinearOperator: @cached(name="svd") def _svd( - self: Float[LinearOperator, "*batch N N"] - ) -> Tuple[Float[LinearOperator, "*batch N N"], Float[Tensor, "... 
N"], Float[LinearOperator, "*batch N N"]]: + self: LinearOperator, # shape: (*batch, N, N) + ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) U_, S_, V_ = self.base_linear_op.svd() U = U_.repeat(*self.batch_repeat, 1, 1) S = S_.repeat(*self.batch_repeat, 1) @@ -328,10 +329,10 @@ def _svd( return U, S, V def _symeig( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Float[Tensor, "*batch M"], Optional[Float[LinearOperator, "*batch N M"]]]: + ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) evals, evecs = self.base_linear_op._symeig(eigenvectors=eigenvectors) evals = evals.repeat(*self.batch_repeat, 1) if eigenvectors: diff --git a/linear_operator/operators/block_diag_linear_operator.py b/linear_operator/operators/block_diag_linear_operator.py index 8abf3bd6..e5b97d0d 100644 --- a/linear_operator/operators/block_diag_linear_operator.py +++ b/linear_operator/operators/block_diag_linear_operator.py @@ -4,7 +4,6 @@ from typing import Callable, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator @@ -66,8 +65,8 @@ def num_blocks(self) -> int: return self.base_linear_op.size(-3) def _add_batch_dim( - self: Float[LinearOperator, "*batch1 M P"], other: Float[torch.Tensor, "*batch2 N C"] - ) -> Float[torch.Tensor, "batch2 ... C"]: + self: LinearOperator, other: torch.Tensor # shape: (*batch1, M, P) or (*batch2, N, C) + ) -> torch.Tensor: # shape: (batch2, ..., C) *batch_shape, num_rows, num_cols = other.shape batch_shape = list(batch_shape) @@ -77,24 +76,26 @@ def _add_batch_dim( @cached(name="cholesky") def _cholesky( - self: Float[LinearOperator, "*batch N N"], upper: Optional[bool] = False - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator chol = self.__class__(self.base_linear_op.cholesky(upper=upper)) return TriangularLinearOperator(chol, upper=upper) def _cholesky_solve( - self: Float[LinearOperator, "*batch N N"], - rhs: Union[Float[LinearOperator, "*batch2 N M"], Float[Tensor, "*batch2 N M"]], + self: LinearOperator, # shape: (*batch, N, N) + rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) upper: Optional[bool] = False, - ) -> Union[Float[LinearOperator, "... N M"], Float[Tensor, "... N M"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) rhs = self._add_batch_dim(rhs) res = self.base_linear_op._cholesky_solve(rhs, upper=upper) res = self._remove_batch_dim(res) return res - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) res = self.base_linear_op._diagonal().contiguous() return res.view(*self.batch_shape, self.size(-1)) @@ -121,15 +122,15 @@ def _remove_batch_dim(self, other): return other def _root_decomposition( - self: Float[LinearOperator, "... N N"] - ) -> Union[Float[torch.Tensor, "... N N"], Float[LinearOperator, "... 
N N"]]: + self: LinearOperator, # shape: (..., N, N) + ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) return self.__class__(self.base_linear_op._root_decomposition()) def _root_inv_decomposition( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) initial_vectors: Optional[torch.Tensor] = None, test_vectors: Optional[torch.Tensor] = None, - ) -> Union[Float[LinearOperator, "... N N"], Float[Tensor, "... N N"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) return self.__class__(self.base_linear_op._root_inv_decomposition(initial_vectors)) def _size(self) -> torch.Size: @@ -140,15 +141,15 @@ def _size(self) -> torch.Size: return torch.Size(shape) def _solve( - self: Float[LinearOperator, "... N N"], - rhs: Float[torch.Tensor, "... N C"], - preconditioner: Optional[Callable[[Float[torch.Tensor, "... N C"]], Float[torch.Tensor, "... N C"]]] = None, + self: LinearOperator, # shape: (..., N, N) + rhs: torch.Tensor, # shape: (..., N, C) + preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) num_tridiag: Optional[int] = 0, ) -> Union[ - Float[torch.Tensor, "... N C"], + torch.Tensor, # shape: (..., N, C) Tuple[ - Float[torch.Tensor, "... N C"], - Float[torch.Tensor, "..."], # Note that in case of a tuple the second term size depends on num_tridiag + torch.Tensor, # shape: (..., N, C) + torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) ], ]: if num_tridiag: @@ -160,13 +161,13 @@ def _solve( return res def inv_quad_logdet( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Optional[Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]]] = None, + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, ) -> Tuple[ - Optional[Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"], Float[Tensor, " 0"]]], - Optional[Float[Tensor, "..."]], + Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) + Optional[Tensor], # shape: (...) ]: if inv_quad_rhs is not None: inv_quad_rhs = self._add_batch_dim(inv_quad_rhs) @@ -185,9 +186,9 @@ def inv_quad_logdet( return inv_quad_res, logdet_res def matmul( - self: Float[LinearOperator, "*batch M N"], - other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]], - ) -> Union[Float[Tensor, "... M P"], Float[Tensor, "... M"], Float[LinearOperator, "... M P"]]: + self: LinearOperator, # shape: (*batch, M, N) + other: Union[Tensor, LinearOperator], # shape: (*batch2, N, P) or (*batch2, N) + ) -> Union[Tensor, LinearOperator]: # shape: (..., M, P) or (..., M) from linear_operator.operators.diag_linear_operator import DiagLinearOperator # this is trivial if we multiply two BlockDiagLinearOperator with matching block sizes @@ -203,8 +204,8 @@ def matmul( @cached(name="svd") def _svd( - self: Float[LinearOperator, "*batch N N"] - ) -> Tuple[Float[LinearOperator, "*batch N N"], Float[Tensor, "... 
N"], Float[LinearOperator, "*batch N N"]]: + self: LinearOperator, # shape: (*batch, N, N) + ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) U, S, V = self.base_linear_op.svd() # Doesn't make much sense to sort here, o/w we lose the structure S = S.reshape(*S.shape[:-2], S.shape[-2:].numel()) @@ -214,10 +215,10 @@ def _svd( return U, S, V def _symeig( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Float[Tensor, "*batch M"], Optional[Float[LinearOperator, "*batch N M"]]]: + ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) evals, evecs = self.base_linear_op._symeig(eigenvectors=eigenvectors) # Doesn't make much sense to sort here, o/w we lose the structure evals = evals.reshape(*evals.shape[:-2], evals.shape[-2:].numel()) diff --git a/linear_operator/operators/block_interleaved_linear_operator.py b/linear_operator/operators/block_interleaved_linear_operator.py index ccfd6bb1..ba6f7d39 100644 --- a/linear_operator/operators/block_interleaved_linear_operator.py +++ b/linear_operator/operators/block_interleaved_linear_operator.py @@ -2,7 +2,6 @@ from typing import Callable, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator @@ -41,24 +40,26 @@ def _add_batch_dim(self, other): @cached(name="cholesky") def _cholesky( - self: Float[LinearOperator, "*batch N N"], upper: Optional[bool] = False - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator chol = self.__class__(self.base_linear_op.cholesky(upper=upper)) return TriangularLinearOperator(chol, upper=upper) def _cholesky_solve( - self: Float[LinearOperator, "*batch N N"], - rhs: Union[Float[LinearOperator, "*batch2 N M"], Float[Tensor, "*batch2 N M"]], + self: LinearOperator, # shape: (*batch, N, N) + rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) upper: Optional[bool] = False, - ) -> Union[Float[LinearOperator, "... N M"], Float[Tensor, "... N M"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) rhs = self._add_batch_dim(rhs) res = self.base_linear_op._cholesky_solve(rhs, upper=upper) res = self._remove_batch_dim(res) return res - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) block_diag = self.base_linear_op._diagonal() return block_diag.mT.contiguous().view(*block_diag.shape[:-2], -1) @@ -86,15 +87,15 @@ def _remove_batch_dim(self, other): return other def _root_decomposition( - self: Float[LinearOperator, "... N N"] - ) -> Union[Float[torch.Tensor, "... N N"], Float[LinearOperator, "... N N"]]: + self: LinearOperator, # shape: (..., N, N) + ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) return self.__class__(self.base_linear_op._root_decomposition()) def _root_inv_decomposition( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) initial_vectors: Optional[torch.Tensor] = None, test_vectors: Optional[torch.Tensor] = None, - ) -> Union[Float[LinearOperator, "... N N"], Float[Tensor, "... 
N N"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) return self.__class__(self.base_linear_op._root_inv_decomposition(initial_vectors)) def _size(self) -> torch.Size: @@ -105,15 +106,15 @@ def _size(self) -> torch.Size: return torch.Size(shape) def _solve( - self: Float[LinearOperator, "... N N"], - rhs: Float[torch.Tensor, "... N C"], - preconditioner: Optional[Callable[[Float[torch.Tensor, "... N C"]], Float[torch.Tensor, "... N C"]]] = None, + self: LinearOperator, # shape: (..., N, N) + rhs: torch.Tensor, # shape: (..., N, C) + preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) num_tridiag: Optional[int] = 0, ) -> Union[ - Float[torch.Tensor, "... N C"], + torch.Tensor, # shape: (..., N, C) Tuple[ - Float[torch.Tensor, "... N C"], - Float[torch.Tensor, "..."], # Note that in case of a tuple the second term size depends on num_tridiag + torch.Tensor, # shape: (..., N, C) + torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) ], ]: if num_tridiag: @@ -125,13 +126,13 @@ def _solve( return res def inv_quad_logdet( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Optional[Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]]] = None, + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, ) -> Tuple[ - Optional[Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"], Float[Tensor, " 0"]]], - Optional[Float[Tensor, "..."]], + Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) + Optional[Tensor], # shape: (...) ]: if inv_quad_rhs is not None: inv_quad_rhs = self._add_batch_dim(inv_quad_rhs) diff --git a/linear_operator/operators/block_linear_operator.py b/linear_operator/operators/block_linear_operator.py index 54862816..93903ef3 100644 --- a/linear_operator/operators/block_linear_operator.py +++ b/linear_operator/operators/block_linear_operator.py @@ -4,7 +4,6 @@ from typing import List, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator @@ -60,8 +59,8 @@ def _add_batch_dim(self, other): raise NotImplementedError def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) batch_shape = torch.Size((*batch_shape, self.base_linear_op.size(-3))) res = self.__class__(self.base_linear_op._expand_batch(batch_shape)) return res @@ -103,9 +102,9 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I return self.__class__(new_base_linear_op, block_dim=-3) def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... 
M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) isvector = rhs.ndimension() == 1 if isvector: rhs = rhs.unsqueeze(1) @@ -151,15 +150,17 @@ def _remove_batch_dim(self, other): raise NotImplementedError def _mul_constant( - self: Float[LinearOperator, "*batch M N"], other: Union[float, torch.Tensor] - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) # We're using a custom method here - the constant mul is applied to the base_lazy tensor # This preserves the block structure from linear_operator.operators.constant_mul_linear_operator import ConstantMulLinearOperator return self.__class__(ConstantMulLinearOperator(self.base_linear_op, other)) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) base_op = self.base_linear_op if isinstance(base_op, LinearOperator): new_base_op = base_op._transpose_nonbatch() @@ -168,8 +169,8 @@ def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[Line return self.__class__(new_base_op) def zero_mean_mvn_samples( - self: Float[LinearOperator, "*batch N N"], num_samples: int - ) -> Float[Tensor, "num_samples *batch N"]: + self: LinearOperator, num_samples: int # shape: (*batch, N, N) + ) -> Tensor: # shape: (num_samples, *batch, N) res = self.base_linear_op.zero_mean_mvn_samples(num_samples) res = self._remove_batch_dim(res.unsqueeze(-1)).squeeze(-1) return res diff --git a/linear_operator/operators/cat_linear_operator.py b/linear_operator/operators/cat_linear_operator.py index 521d8e02..2e8974a4 100644 --- a/linear_operator/operators/cat_linear_operator.py +++ b/linear_operator/operators/cat_linear_operator.py @@ -5,7 +5,6 @@ from typing import List, Optional, Sequence, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator, to_dense @@ -134,7 +133,9 @@ def _split_slice(self, slice_idx: slice) -> Tuple[Sequence[int], List[slice]]: [first_slice] + [_noop_index] * num_middle_tensors + [last_slice], ) - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) if self.cat_dim == -2: res = [] curr_col = 0 @@ -160,8 +161,8 @@ def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, ".. return res def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) batch_dim = self.cat_dim + 2 if batch_dim < 0: if batch_shape[batch_dim] != self.batch_shape[batch_dim]: @@ -304,9 +305,9 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I return res def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... 
M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) output_device = self.device if self.device is not None else rhs.device # make a copy of `rhs` on each device rhs_ = [] @@ -361,7 +362,9 @@ def _permute_batch(self, *dims: int) -> LinearOperator: def _size(self) -> torch.Size: return self._shape - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) if self.cat_dim == -2: new_dim = -1 elif self.cat_dim == -1: @@ -380,17 +383,19 @@ def _unsqueeze_batch(self, dim: int) -> LinearOperator: ) return res - def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + def to_dense( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch, M, N) return torch.cat([to_dense(L) for L in self.linear_ops], dim=self.cat_dim) def inv_quad_logdet( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Optional[Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]]] = None, + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, ) -> Tuple[ - Optional[Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"], Float[Tensor, " 0"]]], - Optional[Float[Tensor, "..."]], + Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) + Optional[Tensor], # shape: (...) ]: res = super().inv_quad_logdet(inv_quad_rhs, logdet, reduce_inv_quad) return tuple(r.to(self.device) for r in res) @@ -407,7 +412,11 @@ def devices(self) -> List[torch.device]: def device_count(self) -> int: return len(set(self.devices)) - def to(self: Float[LinearOperator, "*batch M N"], *args, **kwargs) -> Float[LinearOperator, "*batch M N"]: + def to( + self: LinearOperator, # shape: (*batch, M, N) + *args, + **kwargs, + ) -> LinearOperator: # shape: (*batch, M, N) """ Returns a new CatLinearOperator with device as the output_device and dtype as the dtype. diff --git a/linear_operator/operators/chol_linear_operator.py b/linear_operator/operators/chol_linear_operator.py index 6b502d91..3be5dd92 100644 --- a/linear_operator/operators/chol_linear_operator.py +++ b/linear_operator/operators/chol_linear_operator.py @@ -6,7 +6,6 @@ from typing import Callable, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import LinearOperator @@ -30,7 +29,7 @@ class CholLinearOperator(RootLinearOperator): (i.e. :math:`\mathbf L \mathbf L^\top`). """ - def __init__(self, chol: Float[_TriangularLinearOperatorBase, "*#batch N N"], upper: bool = False): + def __init__(self, chol: _TriangularLinearOperatorBase, upper: bool = False): if not isinstance(chol, _TriangularLinearOperatorBase): warnings.warn( "chol argument to CholLinearOperator should be a TriangularLinearOperator. " @@ -47,33 +46,37 @@ def __init__(self, chol: Float[_TriangularLinearOperatorBase, "*#batch N N"], up self.upper = upper @property - def _chol_diag(self: Float[LinearOperator, "*batch N N"]) -> Float[torch.Tensor, "... 
N"]: + def _chol_diag( + self: LinearOperator, # shape: (*batch, N, N) + ) -> torch.Tensor: # shape: (..., N) return self.root._diagonal() @cached(name="cholesky") def _cholesky( - self: Float[LinearOperator, "*batch N N"], upper: Optional[bool] = False - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) if upper == self.upper: return self.root else: return self.root._transpose_nonbatch() @cached - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) # TODO: Can we be smarter here? return (self.root.to_dense() ** 2).sum(-1) def _solve( - self: Float[LinearOperator, "... N N"], - rhs: Float[torch.Tensor, "... N C"], - preconditioner: Optional[Callable[[Float[torch.Tensor, "... N C"]], Float[torch.Tensor, "... N C"]]] = None, + self: LinearOperator, # shape: (..., N, N) + rhs: torch.Tensor, # shape: (..., N, C) + preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) num_tridiag: Optional[int] = 0, ) -> Union[ - Float[torch.Tensor, "... N C"], + torch.Tensor, # shape: (..., N, C) Tuple[ - Float[torch.Tensor, "... N C"], - Float[torch.Tensor, "..."], # Note that in case of a tuple the second term size depends on num_tridiag + torch.Tensor, # shape: (..., N, C) + torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) ], ]: if num_tridiag: @@ -81,7 +84,9 @@ def _solve( return self.root._cholesky_solve(rhs, upper=self.upper) @cached - def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + def to_dense( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch, M, N) root = self.root if self.upper: res = root._transpose_nonbatch() @ root @@ -90,7 +95,9 @@ def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch return res.to_dense() @cached - def inverse(self: Float[LinearOperator, "*batch N N"]) -> Float[LinearOperator, "*batch N N"]: + def inverse( + self: LinearOperator, # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) """ Returns the inverse of the CholLinearOperator. 
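The result is itself a CholLinearOperator, wrapping a triangular factor with the opposite ``upper`` orientation. Illustrative sketch (assumes a small positive-definite matrix, chosen here only for demonstration):

    >>> L = torch.linalg.cholesky(torch.tensor([[4.0, 1.0], [1.0, 3.0]]))
    >>> A = CholLinearOperator(TriangularLinearOperator(L))
    >>> A_inv = A.inverse()  # also a CholLinearOperator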
""" @@ -98,10 +105,10 @@ def inverse(self: Float[LinearOperator, "*batch N N"]) -> Float[LinearOperator, return CholLinearOperator(TriangularLinearOperator(Linv, upper=not self.upper), upper=not self.upper) def inv_quad( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]], + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Tensor, # shape: (*batch, N, M) or (*batch, N) reduce_inv_quad: bool = True, - ) -> Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"]]: + ) -> Tensor: # shape: (*batch, M) or (*batch) if self.upper: R = self.root._transpose_nonbatch().solve(inv_quad_rhs) else: @@ -112,13 +119,13 @@ def inv_quad( return inv_quad_term def inv_quad_logdet( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Optional[Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]]] = None, + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, ) -> Tuple[ - Optional[Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"], Float[Tensor, " 0"]]], - Optional[Float[Tensor, "..."]], + Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) + Optional[Tensor], # shape: (...) ]: if not self.is_square: raise RuntimeError( @@ -158,19 +165,19 @@ def inv_quad_logdet( return inv_quad_term, logdet_term def root_inv_decomposition( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) initial_vectors: Optional[torch.Tensor] = None, test_vectors: Optional[torch.Tensor] = None, method: Optional[str] = None, - ) -> Union[Float[LinearOperator, "... N N"], Float[Tensor, "... N N"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) inv_root = self.root.inverse() return RootLinearOperator(inv_root._transpose_nonbatch()) def solve( - self: Float[LinearOperator, "... N N"], - right_tensor: Union[Float[Tensor, "... N P"], Float[Tensor, " N"]], - left_tensor: Optional[Float[Tensor, "... O N"]] = None, - ) -> Union[Float[Tensor, "... N P"], Float[Tensor, "... N"], Float[Tensor, "... O P"], Float[Tensor, "... O"]]: + self: LinearOperator, # shape: (..., N, N) + right_tensor: Tensor, # shape: (..., N, P) or (N) + left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) is_vector = right_tensor.ndim == 1 if is_vector: right_tensor = right_tensor.unsqueeze(-1) diff --git a/linear_operator/operators/constant_mul_linear_operator.py b/linear_operator/operators/constant_mul_linear_operator.py index 44215365..994e60e5 100644 --- a/linear_operator/operators/constant_mul_linear_operator.py +++ b/linear_operator/operators/constant_mul_linear_operator.py @@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator @@ -71,17 +70,21 @@ def __init__(self, base_linear_op, constant): self.base_linear_op = base_linear_op self._constant = constant - def _approx_diagonal(self: Float[LinearOperator, "*batch N N"]) -> Float[torch.Tensor, "*batch N"]: + def _approx_diagonal( + self: LinearOperator, # shape: (*batch, N, N) + ) -> torch.Tensor: # shape: (*batch, N) res = self.base_linear_op._approx_diagonal() return res * self._constant.unsqueeze(-1) - def _diagonal(self: Float[LinearOperator, "... 
M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) res = self.base_linear_op._diagonal() return res * self._constant.unsqueeze(-1) def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) return self.__class__( self.base_linear_op._expand_batch(batch_shape), self._constant.expand(*batch_shape) if len(batch_shape) else self._constant, @@ -108,9 +111,9 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I return type(self)(base_linear_op=base_linear_op, constant=constant) def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) res = self.base_linear_op._matmul(rhs) res = res * self.expanded_constant return res @@ -140,14 +143,16 @@ def _size(self) -> torch.Size: return self.base_linear_op.size() def _t_matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]], - ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) res = self.base_linear_op._t_matmul(rhs) res = res * self.expanded_constant return res - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return ConstantMulLinearOperator(self.base_linear_op._transpose_nonbatch(), self._constant) def _unsqueeze_batch(self, dim: int) -> LinearOperator: @@ -171,14 +176,16 @@ def expanded_constant(self) -> Tensor: return constant @cached - def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + def to_dense( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch, M, N) res = self.base_linear_op.to_dense() return res * self.expanded_constant @cached(name="root_decomposition") def root_decomposition( - self: Float[LinearOperator, "*batch N N"], method: Optional[str] = None - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, method: Optional[str] = None # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) if torch.all(self._constant >= 0): base_root = self.base_linear_op.root_decomposition(method=method).root return RootLinearOperator(ConstantMulLinearOperator(base_root, self._constant**0.5)) diff --git a/linear_operator/operators/dense_linear_operator.py b/linear_operator/operators/dense_linear_operator.py index a55db844..739a5baf 100644 --- a/linear_operator/operators/dense_linear_operator.py +++ b/linear_operator/operators/dense_linear_operator.py @@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator, to_dense @@ 
-31,18 +30,20 @@ def __init__(self, tsr): self.tensor = tsr def _cholesky_solve( - self: Float[LinearOperator, "*batch N N"], - rhs: Union[Float[LinearOperator, "*batch2 N M"], Float[Tensor, "*batch2 N M"]], + self: LinearOperator, # shape: (*batch, N, N) + rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) upper: Optional[bool] = False, - ) -> Union[Float[LinearOperator, "... N M"], Float[Tensor, "... N M"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) return torch.cholesky_solve(rhs, self.to_dense(), upper=upper) - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) return self.tensor.diagonal(dim1=-1, dim2=-2) def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) return self.__class__(self.tensor.expand(*batch_shape, *self.matrix_shape)) def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor: @@ -59,9 +60,9 @@ def _isclose(self, other, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: b return torch.isclose(self.tensor, to_dense(other), rtol=rtol, atol=atol, equal_nan=equal_nan) def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) return torch.matmul(self.tensor, rhs) def _prod_batch(self, dim: int) -> LinearOperator: @@ -77,22 +78,26 @@ def _size(self) -> torch.Size: def _sum_batch(self, dim: int) -> LinearOperator: return self.__class__(self.tensor.sum(dim)) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return DenseLinearOperator(self.tensor.mT) def _t_matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]], - ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) return torch.matmul(self.tensor.mT, rhs) - def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + def to_dense( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch, M, N) return self.tensor def __add__( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], - ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... 
M N"]]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) + ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) if isinstance(other, DenseLinearOperator): return DenseLinearOperator(self.tensor + other.tensor) elif isinstance(other, torch.Tensor): diff --git a/linear_operator/operators/diag_linear_operator.py b/linear_operator/operators/diag_linear_operator.py index 957626d9..03c485f0 100644 --- a/linear_operator/operators/diag_linear_operator.py +++ b/linear_operator/operators/diag_linear_operator.py @@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator import settings @@ -23,14 +22,14 @@ class DiagLinearOperator(TriangularLinearOperator): :param diag: Diagonal elements of LinearOperator. """ - def __init__(self, diag: Float[Tensor, "*#batch N"]): + def __init__(self, diag: Tensor): super(TriangularLinearOperator, self).__init__(diag) self._diag = diag def __add__( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], - ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) + ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) if isinstance(other, DiagLinearOperator): return self.add_diagonal(other._diag) from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator @@ -49,23 +48,25 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O @cached(name="cholesky", ignore_args=True) def _cholesky( - self: Float[LinearOperator, "*batch N N"], upper: Optional[bool] = False - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) return self.sqrt() def _cholesky_solve( - self: Float[LinearOperator, "*batch N N"], - rhs: Union[Float[LinearOperator, "*batch2 N M"], Float[Tensor, "*batch2 N M"]], + self: LinearOperator, # shape: (*batch, N, N) + rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) upper: Optional[bool] = False, - ) -> Union[Float[LinearOperator, "... N M"], Float[Tensor, "... N M"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) return rhs / self._diag.unsqueeze(-1).pow(2) def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) return self.__class__(self._diag.expand(*batch_shape, self._diag.size(-1))) - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... 
N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) return self._diag def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor: @@ -79,29 +80,29 @@ def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indice return res def _mul_constant( - self: Float[LinearOperator, "*batch M N"], other: Union[float, torch.Tensor] - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) return self.__class__(self._diag * other.unsqueeze(-1)) def _mul_matrix( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[torch.Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"]], - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[torch.Tensor, LinearOperator], # shape: (..., #M, #N) + ) -> LinearOperator: # shape: (..., M, N) return DiagLinearOperator(self._diag * other._diagonal()) def _prod_batch(self, dim: int) -> LinearOperator: return self.__class__(self._diag.prod(dim)) def _root_decomposition( - self: Float[LinearOperator, "... N N"] - ) -> Union[Float[torch.Tensor, "... N N"], Float[LinearOperator, "... N N"]]: + self: LinearOperator, # shape: (..., N, N) + ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) return self.sqrt() def _root_inv_decomposition( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) initial_vectors: Optional[torch.Tensor] = None, test_vectors: Optional[torch.Tensor] = None, - ) -> Union[Float[LinearOperator, "... N N"], Float[Tensor, "... N N"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) return self.inverse().sqrt() def _size(self) -> torch.Size: @@ -111,13 +112,15 @@ def _sum_batch(self, dim: int) -> LinearOperator: return self.__class__(self._diag.sum(dim)) def _t_matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]], - ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) # Diagonal matrices always commute return self._matmul(rhs) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return self def abs(self) -> LinearOperator: @@ -127,38 +130,44 @@ def abs(self) -> LinearOperator: return self.__class__(self._diag.abs()) def add_diagonal( - self: Float[LinearOperator, "*batch N N"], - diag: Union[Float[torch.Tensor, "... N"], Float[torch.Tensor, "... 
1"], Float[torch.Tensor, ""]], - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, # shape: (*batch, N, N) + diag: torch.Tensor, # shape: (..., N) or (..., 1) or () + ) -> LinearOperator: # shape: (*batch, N, N) shape = torch.broadcast_shapes(self._diag.shape, diag.shape) return DiagLinearOperator(self._diag.expand(shape) + diag.expand(shape)) @cached - def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + def to_dense( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch, M, N) if self._diag.dim() == 0: return self._diag return torch.diag_embed(self._diag) - def exp(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def exp( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ Returns a DiagLinearOperator with all diagonal entries exponentiated. """ return self.__class__(self._diag.exp()) - def inverse(self: Float[LinearOperator, "*batch N N"]) -> Float[LinearOperator, "*batch N N"]: + def inverse( + self: LinearOperator, # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) """ Returns the inverse of the DiagLinearOperator. """ return self.__class__(self._diag.reciprocal()) def inv_quad_logdet( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Optional[Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]]] = None, + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, ) -> Tuple[ - Optional[Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"], Float[Tensor, " 0"]]], - Optional[Float[Tensor, "..."]], + Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) + Optional[Tensor], # shape: (...) ]: # TODO: Use proper batching for inv_quad_rhs (prepand to shape rathern than append) if inv_quad_rhs is None: @@ -183,7 +192,9 @@ def inv_quad_logdet( return inv_quad_term, logdet_term - def log(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def log( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ Returns a DiagLinearOperator with the log of all diagonal entries. """ @@ -192,9 +203,9 @@ def log(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*ba # this needs to be the public "matmul", instead of "_matmul", to hit the special cases before # a MatmulLinearOperator is created. def matmul( - self: Float[LinearOperator, "*batch M N"], - other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]], - ) -> Union[Float[Tensor, "... M P"], Float[Tensor, "... M"], Float[LinearOperator, "... M P"]]: + self: LinearOperator, # shape: (*batch, M, N) + other: Union[Tensor, LinearOperator], # shape: (*batch2, N, P) or (*batch2, N) + ) -> Union[Tensor, LinearOperator]: # shape: (..., M, P) or (..., M) if isinstance(other, Tensor): diag = self._diag if other.ndim == 1 else self._diag.unsqueeze(-1) return diag * other @@ -218,16 +229,16 @@ def matmul( return super().matmul(other) # happens with other structured linear operators def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... 
M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) return self.matmul(rhs) def solve( - self: Float[LinearOperator, "... N N"], - right_tensor: Union[Float[Tensor, "... N P"], Float[Tensor, " N"]], - left_tensor: Optional[Float[Tensor, "... O N"]] = None, - ) -> Union[Float[Tensor, "... N P"], Float[Tensor, "... N"], Float[Tensor, "... O P"], Float[Tensor, "... O"]]: + self: LinearOperator, # shape: (..., N, N) + right_tensor: Tensor, # shape: (..., N, P) or (N) + left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) res = self.inverse()._matmul(right_tensor) if left_tensor is not None: res = left_tensor @ res @@ -243,17 +254,19 @@ def solve_triangular( return rhs return self.solve(right_tensor=rhs) - def sqrt(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def sqrt( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ Returns a DiagLinearOperator with the square root of all diagonal entries. """ return self.__class__(self._diag.sqrt()) def sqrt_inv_matmul( - self: Float[LinearOperator, "*batch N N"], - rhs: Float[Tensor, "*batch N P"], - lhs: Optional[Float[Tensor, "*batch O N"]] = None, - ) -> Union[Float[Tensor, "*batch N P"], Tuple[Float[Tensor, "*batch O P"], Float[Tensor, "*batch O"]]]: + self: LinearOperator, # shape: (*batch, N, N) + rhs: Tensor, # shape: (*batch, N, P) + lhs: Optional[Tensor] = None, # shape: (*batch, O, N) + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: # shape: (*batch, N, P), (*batch, O, P), (*batch, O) matrix_inv_root = self._root_inv_decomposition() if lhs is None: return matrix_inv_root.matmul(rhs) @@ -263,15 +276,15 @@ def sqrt_inv_matmul( return sqrt_inv_matmul, inv_quad def zero_mean_mvn_samples( - self: Float[LinearOperator, "*batch N N"], num_samples: int - ) -> Float[Tensor, "num_samples *batch N"]: + self: LinearOperator, num_samples: int # shape: (*batch, N, N) + ) -> Tensor: # shape: (num_samples, *batch, N) base_samples = torch.randn(num_samples, *self._diag.shape, dtype=self.dtype, device=self.device) return base_samples * self._diag.sqrt() @cached(name="svd") def _svd( - self: Float[LinearOperator, "*batch N N"] - ) -> Tuple[Float[LinearOperator, "*batch N N"], Float[Tensor, "... N"], Float[LinearOperator, "*batch N N"]]: + self: LinearOperator, # shape: (*batch, N, N) + ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) evals, evecs = self._symeig(eigenvectors=True) S = torch.abs(evals) U = evecs @@ -279,10 +292,10 @@ def _svd( return U, S, V def _symeig( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Float[Tensor, "*batch M"], Optional[Float[LinearOperator, "*batch N M"]]]: + ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) evals = self._diag if eigenvectors: diag_values = torch.ones(evals.shape[:-1], device=evals.device, dtype=evals.dtype).unsqueeze(-1) @@ -314,9 +327,9 @@ def __init__(self, diag_values: torch.Tensor, diag_shape: int): self.diag_shape = diag_shape def __add__( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], - ) -> Union[Float[LinearOperator, "... 
M N"], Float[Tensor, "... M N"]]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) + ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) if isinstance(other, ConstantDiagLinearOperator): if other.shape[-1] == self.shape[-1]: return ConstantDiagLinearOperator(self.diag_values + other.diag_values, self.diag_shape) @@ -336,23 +349,25 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O return (res,) @property - def _diag(self: Float[LinearOperator, "... N N"]) -> Float[Tensor, "... N"]: + def _diag( + self: LinearOperator, # shape: (..., N, N) + ) -> Tensor: # shape: (..., N) return self.diag_values.expand(*self.diag_values.shape[:-1], self.diag_shape) def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) return self.__class__(self.diag_values.expand(*batch_shape, 1), diag_shape=self.diag_shape) def _mul_constant( - self: Float[LinearOperator, "*batch M N"], other: Union[float, torch.Tensor] - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) return self.__class__(self.diag_values * other, diag_shape=self.diag_shape) def _mul_matrix( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[torch.Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"]], - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[torch.Tensor, LinearOperator], # shape: (..., #M, #N) + ) -> LinearOperator: # shape: (..., M, N) if isinstance(other, ConstantDiagLinearOperator): if not self.diag_shape == other.diag_shape: raise ValueError( @@ -378,28 +393,34 @@ def abs(self) -> LinearOperator: """ return ConstantDiagLinearOperator(self.diag_values.abs(), diag_shape=self.diag_shape) - def exp(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def exp( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ Returns a DiagLinearOperator with all diagonal entries exponentiated. """ return ConstantDiagLinearOperator(self.diag_values.exp(), diag_shape=self.diag_shape) - def inverse(self: Float[LinearOperator, "*batch N N"]) -> Float[LinearOperator, "*batch N N"]: + def inverse( + self: LinearOperator, # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) """ Returns the inverse of the DiagLinearOperator. """ return ConstantDiagLinearOperator(self.diag_values.reciprocal(), diag_shape=self.diag_shape) - def log(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def log( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ Returns a DiagLinearOperator with the log of all diagonal entries. """ return ConstantDiagLinearOperator(self.diag_values.log(), diag_shape=self.diag_shape) def matmul( - self: Float[LinearOperator, "*batch M N"], - other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]], - ) -> Union[Float[Tensor, "... M P"], Float[Tensor, "... M"], Float[LinearOperator, "... 
M P"]]: + self: LinearOperator, # shape: (*batch, M, N) + other: Union[Tensor, LinearOperator], # shape: (*batch2, N, P) or (*batch2, N) + ) -> Union[Tensor, LinearOperator]: # shape: (..., M, P) or (..., M) if isinstance(other, ConstantDiagLinearOperator): return self._mul_matrix(other) return super().matmul(other) @@ -409,7 +430,9 @@ def solve_triangular( ) -> torch.Tensor: return rhs / self.diag_values - def sqrt(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def sqrt( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ Returns a DiagLinearOperator with the square root of all diagonal entries. """ diff --git a/linear_operator/operators/identity_linear_operator.py b/linear_operator/operators/identity_linear_operator.py index 435e7cf0..1d9b75c9 100644 --- a/linear_operator/operators/identity_linear_operator.py +++ b/linear_operator/operators/identity_linear_operator.py @@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator @@ -63,20 +62,20 @@ def _maybe_reshape_rhs(self, rhs: Union[torch.Tensor, LinearOperator]) -> Union[ @cached(name="cholesky", ignore_args=True) def _cholesky( - self: Float[LinearOperator, "*batch N N"], upper: Optional[bool] = False - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) return self def _cholesky_solve( - self: Float[LinearOperator, "*batch N N"], - rhs: Union[Float[LinearOperator, "*batch2 N M"], Float[Tensor, "*batch2 N M"]], + self: LinearOperator, # shape: (*batch, N, N) + rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) upper: Optional[bool] = False, - ) -> Union[Float[LinearOperator, "... N M"], Float[Tensor, "... N M"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) return self._maybe_reshape_rhs(rhs) def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) return IdentityLinearOperator( diag_shape=self.diag_shape, batch_shape=batch_shape, dtype=self.dtype, device=self.device ) @@ -96,20 +95,20 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I return super()._getitem(row_index, col_index, *batch_indices) def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) return self._maybe_reshape_rhs(rhs) def _mul_constant( - self: Float[LinearOperator, "*batch M N"], other: Union[float, torch.Tensor] - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) return ConstantDiagLinearOperator(self.diag_values * other, diag_shape=self.diag_shape) def _mul_matrix( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[torch.Tensor, "... #M #N"], Float[LinearOperator, "... 
#M #N"]], - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[torch.Tensor, LinearOperator], # shape: (..., #M, #N) + ) -> LinearOperator: # shape: (..., M, N) return other def _permute_batch(self, *dims: int) -> LinearOperator: @@ -126,15 +125,15 @@ def _prod_batch(self, dim: int) -> LinearOperator: ) def _root_decomposition( - self: Float[LinearOperator, "... N N"] - ) -> Union[Float[torch.Tensor, "... N N"], Float[LinearOperator, "... N N"]]: + self: LinearOperator, # shape: (..., N, N) + ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) return self.sqrt() def _root_inv_decomposition( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) initial_vectors: Optional[torch.Tensor] = None, test_vectors: Optional[torch.Tensor] = None, - ) -> Union[Float[LinearOperator, "... N N"], Float[Tensor, "... N N"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) return self.inverse().sqrt() def _size(self) -> torch.Size: @@ -142,24 +141,26 @@ def _size(self) -> torch.Size: @cached(name="svd") def _svd( - self: Float[LinearOperator, "*batch N N"] - ) -> Tuple[Float[LinearOperator, "*batch N N"], Float[Tensor, "... N"], Float[LinearOperator, "*batch N N"]]: + self: LinearOperator, # shape: (*batch, N, N) + ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) return self, self._diag, self def _symeig( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Float[Tensor, "*batch M"], Optional[Float[LinearOperator, "*batch N M"]]]: + ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) return self._diag, self def _t_matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]], - ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... 
N P"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) return self._maybe_reshape_rhs(rhs) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return self def _unsqueeze_batch(self, dim: int) -> LinearOperator: @@ -173,20 +174,24 @@ def _unsqueeze_batch(self, dim: int) -> LinearOperator: def abs(self) -> LinearOperator: return self - def exp(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def exp( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) return self - def inverse(self: Float[LinearOperator, "*batch N N"]) -> Float[LinearOperator, "*batch N N"]: + def inverse( + self: LinearOperator, # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) return self def inv_quad_logdet( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Optional[Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]]] = None, + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, ) -> Tuple[ - Optional[Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"], Float[Tensor, " 0"]]], - Optional[Float[Tensor, "..."]], + Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) + Optional[Tensor], # shape: (...) ]: # TODO: Use proper batching for inv_quad_rhs (prepand to shape rather than append) if inv_quad_rhs is None: @@ -204,15 +209,17 @@ def inv_quad_logdet( return inv_quad_term, logdet_term - def log(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def log( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) return ZeroLinearOperator( *self._batch_shape, self.diag_shape, self.diag_shape, dtype=self._dtype, device=self._device ) def matmul( - self: Float[LinearOperator, "*batch M N"], - other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]], - ) -> Union[Float[Tensor, "... M P"], Float[Tensor, "... M"], Float[LinearOperator, "... M P"]]: + self: LinearOperator, # shape: (*batch, M, N) + other: Union[Tensor, LinearOperator], # shape: (*batch2, N, P) or (*batch2, N) + ) -> Union[Tensor, LinearOperator]: # shape: (..., M, P) or (..., M) is_vec = False if other.dim() == 1: is_vec = True @@ -223,23 +230,25 @@ def matmul( return res def solve( - self: Float[LinearOperator, "... N N"], - right_tensor: Union[Float[Tensor, "... N P"], Float[Tensor, " N"]], - left_tensor: Optional[Float[Tensor, "... O N"]] = None, - ) -> Union[Float[Tensor, "... N P"], Float[Tensor, "... N"], Float[Tensor, "... O P"], Float[Tensor, "... 
O"]]: + self: LinearOperator, # shape: (..., N, N) + right_tensor: Tensor, # shape: (..., N, P) or (N) + left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) res = self._maybe_reshape_rhs(right_tensor) if left_tensor is not None: res = left_tensor @ res return res - def sqrt(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def sqrt( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) return self def sqrt_inv_matmul( - self: Float[LinearOperator, "*batch N N"], - rhs: Float[Tensor, "*batch N P"], - lhs: Optional[Float[Tensor, "*batch O N"]] = None, - ) -> Union[Float[Tensor, "*batch N P"], Tuple[Float[Tensor, "*batch O P"], Float[Tensor, "*batch O"]]]: + self: LinearOperator, # shape: (*batch, N, N) + rhs: Tensor, # shape: (*batch, N, P) + lhs: Optional[Tensor] = None, # shape: (*batch, O, N) + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: # shape: (*batch, N, P), (*batch, O, P), (*batch, O) if lhs is None: return self._maybe_reshape_rhs(rhs) else: @@ -253,12 +262,16 @@ def type(self: LinearOperator, dtype: torch.dtype) -> LinearOperator: ) def zero_mean_mvn_samples( - self: Float[LinearOperator, "*batch N N"], num_samples: int - ) -> Float[Tensor, "num_samples *batch N"]: + self: LinearOperator, num_samples: int # shape: (*batch, N, N) + ) -> Tensor: # shape: (num_samples, *batch, N) base_samples = torch.randn(num_samples, *self.shape[:-1], dtype=self.dtype, device=self.device) return base_samples - def to(self: Float[LinearOperator, "*batch M N"], *args, **kwargs) -> Float[LinearOperator, "*batch M N"]: + def to( + self: LinearOperator, # shape: (*batch, M, N) + *args, + **kwargs, + ) -> LinearOperator: # shape: (*batch, M, N) # Overwrite the to() method in _linear_operator to also convert the dtype and device saved in _kwargs. diff --git a/linear_operator/operators/interpolated_linear_operator.py b/linear_operator/operators/interpolated_linear_operator.py index d088eb25..89455dde 100644 --- a/linear_operator/operators/interpolated_linear_operator.py +++ b/linear_operator/operators/interpolated_linear_operator.py @@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator @@ -94,14 +93,18 @@ def __init__( self.right_interp_indices = right_interp_indices self.right_interp_values = right_interp_values - def _approx_diagonal(self: Float[LinearOperator, "*batch N N"]) -> Float[torch.Tensor, "*batch N"]: + def _approx_diagonal( + self: LinearOperator, # shape: (*batch, N, N) + ) -> torch.Tensor: # shape: (*batch, N) base_diag_root = self.base_linear_op._diagonal().sqrt() left_res = left_interp(self.left_interp_indices, self.left_interp_values, base_diag_root.unsqueeze(-1)) right_res = left_interp(self.right_interp_indices, self.right_interp_values, base_diag_root.unsqueeze(-1)) res = left_res * right_res return res.squeeze(-1) - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) if isinstance(self.base_linear_op, RootLinearOperator) and isinstance( self.base_linear_op.root, DenseLinearOperator ): @@ -116,8 +119,8 @@ def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, ".. 
return super(InterpolatedLinearOperator, self)._diagonal() def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) return self.__class__( self.base_linear_op._expand_batch(batch_shape), self.left_interp_indices.expand(*batch_shape, *self.left_interp_indices.shape[-2:]), @@ -189,9 +192,9 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I return res def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) # Get sparse tensor representations of left/right interp matrices left_interp_t = self._sparse_left_interp_t(self.left_interp_indices, self.left_interp_values) right_interp_t = self._sparse_right_interp_t(self.right_interp_indices, self.right_interp_values) @@ -218,8 +221,8 @@ def _matmul( return res def _mul_constant( - self: Float[LinearOperator, "*batch M N"], other: Union[float, torch.Tensor] - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) # We're using a custom method here - the constant mul is applied to the base_lazy tensor # This preserves the interpolated structure return self.__class__( @@ -231,9 +234,9 @@ def _mul_constant( ) def _t_matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]], - ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) # Get sparse tensor representations of left/right interp matrices left_interp_t = self._sparse_left_interp_t(self.left_interp_indices, self.left_interp_values) right_interp_t = self._sparse_right_interp_t(self.right_interp_indices, self.right_interp_values) @@ -331,7 +334,9 @@ def _size(self) -> torch.Size: self.base_linear_op.batch_shape + (self.left_interp_indices.size(-2), self.right_interp_indices.size(-2)) ) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) res = self.__class__( self.base_linear_op.mT, self.right_interp_indices, @@ -408,9 +413,9 @@ def _sum_batch(self, dim: int) -> LinearOperator: ) def matmul( - self: Float[LinearOperator, "*batch M N"], - other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]], - ) -> Union[Float[Tensor, "... M P"], Float[Tensor, "... M"], Float[LinearOperator, "... M P"]]: + self: LinearOperator, # shape: (*batch, M, N) + other: Union[Tensor, LinearOperator], # shape: (*batch2, N, P) or (*batch2, N) + ) -> Union[Tensor, LinearOperator]: # shape: (..., M, P) or (..., M) # We're using a custom matmul here, because it is significantly faster than # what we get from the function factory. 
# The _matmul_closure is optimized for repeated calls, such as for _solve @@ -448,8 +453,8 @@ def matmul( return res def zero_mean_mvn_samples( - self: Float[LinearOperator, "*batch N N"], num_samples: int - ) -> Float[Tensor, "num_samples *batch N"]: + self: LinearOperator, num_samples: int # shape: (*batch, N, N) + ) -> Tensor: # shape: (num_samples, *batch, N) base_samples = self.base_linear_op.zero_mean_mvn_samples(num_samples) batch_iter = tuple(range(1, base_samples.dim())) base_samples = base_samples.permute(*batch_iter, 0) @@ -457,7 +462,11 @@ def zero_mean_mvn_samples( batch_iter = tuple(range(res.dim() - 1)) return res.permute(-1, *batch_iter).contiguous() - def to(self: Float[LinearOperator, "*batch M N"], *args, **kwargs) -> Float[LinearOperator, "*batch M N"]: + def to( + self: LinearOperator, # shape: (*batch, M, N) + *args, + **kwargs, + ) -> LinearOperator: # shape: (*batch, M, N) # Overwrite the to() method in _linear_operator to avoid converting index matrices to float. # Will only convert both dtype and device when arg and dtype are both int/float. diff --git a/linear_operator/operators/keops_linear_operator.py b/linear_operator/operators/keops_linear_operator.py index 9fd0cd83..76af001b 100644 --- a/linear_operator/operators/keops_linear_operator.py +++ b/linear_operator/operators/keops_linear_operator.py @@ -5,7 +5,6 @@ from typing import Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator @@ -28,7 +27,9 @@ def __init__(self, x1, x2, covar_func, **params): self.params = params @cached(name="kernel_diag") - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) """ Explicitly compute kernel diag via covar_func when it is needed rather than relying on lazy tensor ops. """ @@ -40,15 +41,17 @@ def covar_mat(self): return self.covar_func(self.x1, self.x2, **self.params) def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) return self.covar_mat @ rhs.contiguous() def _size(self) -> torch.Size: return torch.Size(self.covar_mat.shape) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return KeOpsLinearOperator(self.x2, self.x1, self.covar_func) def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor: diff --git a/linear_operator/operators/kernel_linear_operator.py b/linear_operator/operators/kernel_linear_operator.py index 0d1c009d..947acc6d 100644 --- a/linear_operator/operators/kernel_linear_operator.py +++ b/linear_operator/operators/kernel_linear_operator.py @@ -3,7 +3,6 @@ import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import LinearOperator, to_dense @@ -132,9 +131,9 @@ def _covar_func(x1, x2, lengthscale, outputscale): def __init__( self, - x1: Float[Tensor, "... M D"], - x2: Float[Tensor, "... 
N D"], - covar_func: Callable[..., Float[Union[Tensor, LinearOperator], "... M N"]], + x1: Tensor, # shape: (..., M, D) + x2: Tensor, # shape: (..., N, D) + covar_func: Callable[..., Union[Tensor, LinearOperator]], # shape: (..., M, N) num_outputs_per_input: Tuple[int, int] = (1, 1), num_nonbatch_dimensions: Optional[Dict[str, int]] = None, **params: Union[Tensor, Any], @@ -223,7 +222,9 @@ def __init__( self.num_nonbatch_dimensions = num_nonbatch_dimensions @cached(name="kernel_diag") - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) # Explicitly compute kernel diag via covar_func when it is needed rather than relying on lazy tensor ops. # We will do this by shoving all of the data into a batch dimension (i.e. compute a N x ... x 1 x 1 kernel # or a N x ... x num_outs-per_in x num_outs_per_in kernel) @@ -248,7 +249,9 @@ def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, ".. @property @cached(name="covar_mat") - def covar_mat(self: Float[LinearOperator, "... M N"]) -> Float[Union[Tensor, LinearOperator], "... M N"]: + def covar_mat( + self: LinearOperator, # shape: (..., M, N) + ) -> Union[Tensor, LinearOperator]: # shape: (..., M, N) return self.covar_func(self.x1, self.x2, **self.tensor_params, **self.nontensor_params) def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor: @@ -368,9 +371,9 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I ) def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... 
M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) return self.covar_mat @ rhs.contiguous() def _permute_batch(self, *dims: int) -> LinearOperator: @@ -400,7 +403,9 @@ def _size(self) -> torch.Size: ] ) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return self.__class__( self.x2, self.x1, diff --git a/linear_operator/operators/kronecker_product_added_diag_linear_operator.py b/linear_operator/operators/kronecker_product_added_diag_linear_operator.py index d4a166b2..5e703c58 100644 --- a/linear_operator/operators/kronecker_product_added_diag_linear_operator.py +++ b/linear_operator/operators/kronecker_product_added_diag_linear_operator.py @@ -3,7 +3,6 @@ from typing import Callable, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator import settings @@ -64,13 +63,13 @@ def __init__(self, *linear_ops, preconditioner_override=None): self._diag_is_constant = isinstance(self.diag_tensor, ConstantDiagLinearOperator) def inv_quad_logdet( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Optional[Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]]] = None, + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, ) -> Tuple[ - Optional[Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"], Float[Tensor, " 0"]]], - Optional[Float[Tensor, "..."]], + Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) + Optional[Tensor], # shape: (...) ]: if inv_quad_rhs is not None: inv_quad_term, _ = super().inv_quad_logdet( @@ -81,7 +80,9 @@ def inv_quad_logdet( logdet_term = self._logdet() if logdet else None return inv_quad_term, logdet_term - def _logdet(self: Float[LinearOperator, "*batch N N"]) -> Float[Tensor, " *batch"]: + def _logdet( + self: LinearOperator, # shape: (*batch, N, N) + ) -> Tensor: # shape: (*batch) if self._diag_is_constant: # symeig requires computing the eigenvectors for it to be differentiable evals, _ = self.linear_op._symeig(eigenvectors=True) @@ -130,15 +131,15 @@ def _preconditioner(self) -> Tuple[Optional[Callable], Optional[LinearOperator], return None, None, None def _solve( - self: Float[LinearOperator, "... N N"], - rhs: Float[torch.Tensor, "... N C"], - preconditioner: Optional[Callable[[Float[torch.Tensor, "... N C"]], Float[torch.Tensor, "... N C"]]] = None, + self: LinearOperator, # shape: (..., N, N) + rhs: torch.Tensor, # shape: (..., N, C) + preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) num_tridiag: Optional[int] = 0, ) -> Union[ - Float[torch.Tensor, "... N C"], + torch.Tensor, # shape: (..., N, C) Tuple[ - Float[torch.Tensor, "... N C"], - Float[torch.Tensor, "..."], # Note that in case of a tuple the second term size depends on num_tridiag + torch.Tensor, # shape: (..., N, C) + torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) ], ]: @@ -222,8 +223,8 @@ def _solve( return super()._solve(rhs, preconditioner=preconditioner, num_tridiag=num_tridiag) def _root_decomposition( - self: Float[LinearOperator, "... 
N N"] - ) -> Union[Float[torch.Tensor, "... N N"], Float[LinearOperator, "... N N"]]: + self: LinearOperator, # shape: (..., N, N) + ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) if self._diag_is_constant: evals, q_matrix = self.linear_op.diagonalization() updated_evals = DiagLinearOperator((evals + self.diag_tensor._diagonal()).pow(0.5)) @@ -255,10 +256,10 @@ def _root_decomposition( return super()._root_decomposition() def _root_inv_decomposition( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) initial_vectors: Optional[torch.Tensor] = None, test_vectors: Optional[torch.Tensor] = None, - ) -> Union[Float[LinearOperator, "... N N"], Float[Tensor, "... N N"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) if self._diag_is_constant: evals, q_matrix = self.linear_op.diagonalization() inv_sqrt_evals = DiagLinearOperator((evals + self.diag_tensor._diagonal()).pow(-0.5)) @@ -290,10 +291,10 @@ def _root_inv_decomposition( return super()._root_inv_decomposition(initial_vectors=initial_vectors) def _symeig( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Float[Tensor, "*batch M"], Optional[Float[LinearOperator, "*batch N M"]]]: + ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) # return_evals_as_lazy is a flag to return the eigenvalues as a lazy tensor # which is useful for root decompositions here (see the root_decomposition # method above) @@ -305,9 +306,9 @@ def _symeig( return super()._symeig(eigenvectors=eigenvectors) def __add__( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], - ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) + ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) if isinstance(other, ConstantDiagLinearOperator) and self._diag_is_constant: # the other cases have only partial implementations return KroneckerProductAddedDiagLinearOperator(self.linear_op, self.diag_tensor + other) diff --git a/linear_operator/operators/kronecker_product_linear_operator.py b/linear_operator/operators/kronecker_product_linear_operator.py index 1cece8b0..3a98c4a5 100644 --- a/linear_operator/operators/kronecker_product_linear_operator.py +++ b/linear_operator/operators/kronecker_product_linear_operator.py @@ -5,7 +5,6 @@ from typing import Callable, List, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator import settings @@ -68,7 +67,7 @@ class KroneckerProductLinearOperator(LinearOperator): :param linear_ops: :math:`\boldsymbol K_1, \ldots, \boldsymbol K_P`: the LinearOperators in the Kronecker product. """ - def __init__(self, *linear_ops: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"]]): + def __init__(self, *linear_ops: Union[Tensor, LinearOperator]): try: linear_ops = tuple(to_linear_operator(linear_op) for linear_op in linear_ops) except TypeError: @@ -96,9 +95,9 @@ def __init__(self, *linear_ops: Union[Float[Tensor, "... #M #N"], Float[LinearOp self.linear_ops = linear_ops def __add__( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... 
#M #N"], float], - ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) + ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) if isinstance(other, (KroneckerProductDiagLinearOperator, ConstantDiagLinearOperator)): from linear_operator.operators.kronecker_product_added_diag_linear_operator import ( KroneckerProductAddedDiagLinearOperator, @@ -114,9 +113,9 @@ def __add__( return super().__add__(other) def add_diagonal( - self: Float[LinearOperator, "*batch N N"], - diag: Union[Float[torch.Tensor, "... N"], Float[torch.Tensor, "... 1"], Float[torch.Tensor, ""]], - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, # shape: (*batch, N, N) + diag: torch.Tensor, # shape: (..., N) or (..., 1) or () + ) -> LinearOperator: # shape: (*batch, N, N) from linear_operator.operators.kronecker_product_added_diag_linear_operator import ( KroneckerProductAddedDiagLinearOperator, ) @@ -145,27 +144,29 @@ def add_diagonal( return KroneckerProductAddedDiagLinearOperator(self, diag_tensor) def diagonalization( - self: Float[LinearOperator, "*batch N N"], method: Optional[str] = None - ) -> Tuple[Float[Tensor, "*batch M"], Optional[Float[LinearOperator, "*batch N M"]]]: + self: LinearOperator, method: Optional[str] = None # shape: (*batch, N, N) + ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) if method is None: method = "symeig" return super().diagonalization(method=method) @cached - def inverse(self: Float[LinearOperator, "*batch N N"]) -> Float[LinearOperator, "*batch N N"]: + def inverse( + self: LinearOperator, # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) # here we use that (A \kron B)^-1 = A^-1 \kron B^-1 # TODO: Investigate under what conditions computing individual individual inverses makes sense inverses = [lt.inverse() for lt in self.linear_ops] return self.__class__(*inverses) def inv_quad_logdet( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Optional[Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]]] = None, + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, ) -> Tuple[ - Optional[Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"], Float[Tensor, " 0"]]], - Optional[Float[Tensor, "..."]], + Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) + Optional[Tensor], # shape: (...) ]: if inv_quad_rhs is not None: inv_quad_term, _ = super().inv_quad_logdet( @@ -178,17 +179,19 @@ def inv_quad_logdet( @cached(name="cholesky") def _cholesky( - self: Float[LinearOperator, "*batch N N"], upper: Optional[bool] = False - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) chol_factors = [lt.cholesky(upper=upper) for lt in self.linear_ops] return KroneckerProductTriangularLinearOperator(*chol_factors, upper=upper) - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) return _kron_diag(*self.linear_ops) def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... 
M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) return self.__class__(*[linear_op._expand_batch(batch_shape) for linear_op in self.linear_ops]) def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor: @@ -212,15 +215,15 @@ def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indice return res def _solve( - self: Float[LinearOperator, "... N N"], - rhs: Float[torch.Tensor, "... N C"], - preconditioner: Optional[Callable[[Float[torch.Tensor, "... N C"]], Float[torch.Tensor, "... N C"]]] = None, + self: LinearOperator, # shape: (..., N, N) + rhs: torch.Tensor, # shape: (..., N, C) + preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) num_tridiag: Optional[int] = 0, ) -> Union[ - Float[torch.Tensor, "... N C"], + torch.Tensor, # shape: (..., N, C) Tuple[ - Float[torch.Tensor, "... N C"], - Float[torch.Tensor, "..."], # Note that in case of a tuple the second term size depends on num_tridiag + torch.Tensor, # shape: (..., N, C) + torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) ], ]: # Computes solve by exploiting the identity (A \kron B)^-1 = A^-1 \kron B^-1 @@ -258,15 +261,17 @@ def _inv_matmul(self, right_tensor, left_tensor=None): res = left_tensor @ res return res - def _logdet(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, " *batch"]: + def _logdet( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch) evals, _ = self.diagonalization() logdet = evals.clamp(min=1e-7).log().sum(-1) return logdet def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) is_vec = rhs.ndimension() == 1 if is_vec: rhs = rhs.unsqueeze(-1) @@ -279,8 +284,8 @@ def _matmul( @cached(name="root_decomposition") def root_decomposition( - self: Float[LinearOperator, "*batch N N"], method: Optional[str] = None - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, method: Optional[str] = None # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) from linear_operator.operators import RootLinearOperator # return a dense root decomposition if the matrix is small @@ -293,11 +298,11 @@ def root_decomposition( @cached(name="root_inv_decomposition") def root_inv_decomposition( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) initial_vectors: Optional[torch.Tensor] = None, test_vectors: Optional[torch.Tensor] = None, method: Optional[str] = None, - ) -> Union[Float[LinearOperator, "... N N"], Float[Tensor, "... N N"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) from linear_operator.operators import RootLinearOperator # return a dense root decomposition if the matrix is small @@ -316,8 +321,8 @@ def _size(self) -> torch.Size: @cached(name="svd") def _svd( - self: Float[LinearOperator, "*batch N N"] - ) -> Tuple[Float[LinearOperator, "*batch N N"], Float[Tensor, "... 
N"], Float[LinearOperator, "*batch N N"]]: + self: LinearOperator, # shape: (*batch, N, N) + ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) U, S, V = [], [], [] for lt in self.linear_ops: U_, S_, V_ = lt.svd() @@ -330,10 +335,10 @@ def _svd( return U, S, V def _symeig( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Float[Tensor, "*batch M"], Optional[Float[LinearOperator, "*batch N M"]]]: + ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) # return_evals_as_lazy is a flag to return the eigenvalues as a lazy tensor # which is useful for root decompositions here (see the root_decomposition # method above) @@ -354,9 +359,9 @@ def _symeig( return evals, evecs def _t_matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]], - ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) is_vec = rhs.ndimension() == 1 if is_vec: rhs = rhs.unsqueeze(-1) @@ -367,7 +372,9 @@ def _t_matmul( res = res.squeeze(-1) return res - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return self.__class__(*(linear_op._transpose_nonbatch() for linear_op in self.linear_ops), **self._kwargs) @@ -381,22 +388,24 @@ def __init__(self, *linear_ops, upper=False): self.upper = upper @cached - def inverse(self: Float[LinearOperator, "*batch N N"]) -> Float[LinearOperator, "*batch N N"]: + def inverse( + self: LinearOperator, # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) # here we use that (A \kron B)^-1 = A^-1 \kron B^-1 inverses = [lt.inverse() for lt in self.linear_ops] return self.__class__(*inverses, upper=self.upper) @cached(name="cholesky") def _cholesky( - self: Float[LinearOperator, "*batch N N"], upper: Optional[bool] = False - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) raise NotImplementedError("_cholesky not applicable to triangular lazy tensors") def _cholesky_solve( - self: Float[LinearOperator, "*batch N N"], - rhs: Union[Float[LinearOperator, "*batch2 N M"], Float[Tensor, "*batch2 N M"]], + self: LinearOperator, # shape: (*batch, N, N) + rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) upper: Optional[bool] = False, - ) -> Union[Float[LinearOperator, "... N M"], Float[Tensor, "... 
N M"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) if upper: # res = (U.T @ U)^-1 @ v = U^-1 @ U^-T @ v w = self._transpose_nonbatch().solve(rhs) @@ -408,17 +417,17 @@ def _cholesky_solve( return res def _symeig( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Float[Tensor, "*batch M"], Optional[Float[LinearOperator, "*batch N M"]]]: + ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) raise NotImplementedError("_symeig not applicable to triangular lazy tensors") def solve( - self: Float[LinearOperator, "... N N"], - right_tensor: Union[Float[Tensor, "... N P"], Float[Tensor, " N"]], - left_tensor: Optional[Float[Tensor, "... O N"]] = None, - ) -> Union[Float[Tensor, "... N P"], Float[Tensor, "... N"], Float[Tensor, "... O P"], Float[Tensor, "... O"]]: + self: LinearOperator, # shape: (..., N, N) + right_tensor: Tensor, # shape: (..., N, P) or (N) + left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) # For triangular components, using triangular-triangular substition should generally be good return self._inv_matmul(right_tensor=right_tensor, left_tensor=left_tensor) @@ -443,23 +452,25 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O @cached(name="cholesky") def _cholesky( - self: Float[LinearOperator, "*batch N N"], upper: Optional[bool] = False - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) chol_factors = [lt.cholesky(upper=upper) for lt in self.linear_ops] return KroneckerProductDiagLinearOperator(*chol_factors) @property - def _diag(self: Float[LinearOperator, "*batch N N"]) -> Float[Tensor, "*batch N"]: + def _diag( + self: LinearOperator, # shape: (*batch, N, N) + ) -> Tensor: # shape: (*batch, N) return _kron_diag(*self.linear_ops) def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... 
M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) return KroneckerProductTriangularLinearOperator._expand_batch(self, batch_shape) def _mul_constant( - self: Float[LinearOperator, "*batch M N"], other: Union[float, torch.Tensor] - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) return DiagLinearOperator(self._diag * other.unsqueeze(-1)) def _size(self) -> torch.Size: @@ -472,10 +483,10 @@ def _size(self) -> torch.Size: return torch.Size([*batch_shape, N, N]) def _symeig( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Float[Tensor, "*batch M"], Optional[Float[LinearOperator, "*batch N M"]]]: + ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) # return_evals_as_lazy is a flag to return the eigenvalues as a lazy tensor # which is useful for root decompositions here (see the root_decomposition # method above) @@ -501,11 +512,15 @@ def abs(self) -> LinearOperator: """ return self.__class__(*[lt.abs() for lt in self.linear_ops]) - def exp(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def exp( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) raise NotImplementedError(f"torch.exp({self.__class__.__name__}) is not implemented.") @cached - def inverse(self: Float[LinearOperator, "*batch N N"]) -> Float[LinearOperator, "*batch N N"]: + def inverse( + self: LinearOperator, # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) """ Returns the inverse of the DiagLinearOperator. """ @@ -513,10 +528,14 @@ def inverse(self: Float[LinearOperator, "*batch N N"]) -> Float[LinearOperator, inverses = [lt.inverse() for lt in self.linear_ops] return self.__class__(*inverses) - def log(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def log( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) raise NotImplementedError(f"torch.log({self.__class__.__name__}) is not implemented.") - def sqrt(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def sqrt( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ Returns a DiagLinearOperator with the square root of all diagonal entries. 
""" diff --git a/linear_operator/operators/low_rank_root_added_diag_linear_operator.py b/linear_operator/operators/low_rank_root_added_diag_linear_operator.py index 6e7f1aba..705552f3 100644 --- a/linear_operator/operators/low_rank_root_added_diag_linear_operator.py +++ b/linear_operator/operators/low_rank_root_added_diag_linear_operator.py @@ -2,7 +2,6 @@ from typing import Callable, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators import to_dense @@ -46,8 +45,8 @@ def chol_cap_mat(self): return chol_cap_mat def _mul_constant( - self: Float[LinearOperator, "*batch M N"], other: Union[float, torch.Tensor] - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) # We have to over-ride this here for the case where the constant is negative if other > 0: res = super()._mul_constant(other) @@ -59,15 +58,15 @@ def _preconditioner(self) -> Tuple[Optional[Callable], Optional[LinearOperator], return None, None, None def _solve( - self: Float[LinearOperator, "... N N"], - rhs: Float[torch.Tensor, "... N C"], - preconditioner: Optional[Callable[[Float[torch.Tensor, "... N C"]], Float[torch.Tensor, "... N C"]]] = None, + self: LinearOperator, # shape: (..., N, N) + rhs: torch.Tensor, # shape: (..., N, C) + preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) num_tridiag: Optional[int] = 0, ) -> Union[ - Float[torch.Tensor, "... N C"], + torch.Tensor, # shape: (..., N, C) Tuple[ - Float[torch.Tensor, "... N C"], - Float[torch.Tensor, "..."], # Note that in case of a tuple the second term size depends on num_tridiag + torch.Tensor, # shape: (..., N, C) + torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) ], ]: A_inv = self._diag_tensor.inverse() # This is fine since it's a DiagLinearOperator @@ -98,9 +97,9 @@ def _logdet(self): return logdet_term def __add__( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], - ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) + ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) from linear_operator.operators.diag_linear_operator import DiagLinearOperator if isinstance(other, DiagLinearOperator): @@ -109,13 +108,13 @@ def __add__( return AddedDiagLinearOperator(self._linear_op + other, self._diag_tensor) def inv_quad_logdet( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Optional[Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]]] = None, + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, ) -> Tuple[ - Optional[Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"], Float[Tensor, " 0"]]], - Optional[Float[Tensor, "..."]], + Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) + Optional[Tensor], # shape: (...) ]: if not self.is_square: raise RuntimeError( @@ -157,10 +156,10 @@ def inv_quad_logdet( return inv_quad_term, logdet_term def solve( - self: Float[LinearOperator, "... N N"], - right_tensor: Union[Float[Tensor, "... N P"], Float[Tensor, " N"]], - left_tensor: Optional[Float[Tensor, "... 
O N"]] = None, - ) -> Union[Float[Tensor, "... N P"], Float[Tensor, "... N"], Float[Tensor, "... O P"], Float[Tensor, "... O"]]: + self: LinearOperator, # shape: (..., N, N) + right_tensor: Tensor, # shape: (..., N, P) or (N) + left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) if not self.is_square: raise RuntimeError( "solve only operates on (batches of) square (positive semi-definite) LinearOperators. " diff --git a/linear_operator/operators/low_rank_root_linear_operator.py b/linear_operator/operators/low_rank_root_linear_operator.py index 3cd39e2c..8f5f0c84 100644 --- a/linear_operator/operators/low_rank_root_linear_operator.py +++ b/linear_operator/operators/low_rank_root_linear_operator.py @@ -2,7 +2,6 @@ from typing import Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import LinearOperator @@ -19,9 +18,9 @@ class LowRankRootLinearOperator(RootLinearOperator): """ def add_diagonal( - self: Float[LinearOperator, "*batch N N"], - diag: Union[Float[torch.Tensor, "... N"], Float[torch.Tensor, "... 1"], Float[torch.Tensor, ""]], - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, # shape: (*batch, N, N) + diag: torch.Tensor, # shape: (..., N) or (..., 1) or () + ) -> LinearOperator: # shape: (*batch, N, N) from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator from linear_operator.operators.low_rank_root_added_diag_linear_operator import ( LowRankRootAddedDiagLinearOperator, @@ -51,9 +50,9 @@ def add_diagonal( return LowRankRootAddedDiagLinearOperator(self, diag_tensor) def __add__( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], - ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... 
M N"]]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) + ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) from linear_operator.operators.diag_linear_operator import DiagLinearOperator from linear_operator.operators.low_rank_root_added_diag_linear_operator import ( LowRankRootAddedDiagLinearOperator, diff --git a/linear_operator/operators/masked_linear_operator.py b/linear_operator/operators/masked_linear_operator.py index b9ffd148..9b213152 100644 --- a/linear_operator/operators/masked_linear_operator.py +++ b/linear_operator/operators/masked_linear_operator.py @@ -1,7 +1,6 @@ from typing import List, Optional, Tuple, Union import torch -from jaxtyping import Bool, Float from torch import Tensor from linear_operator.operators._linear_operator import _is_noop_index, IndexType, LinearOperator @@ -17,9 +16,9 @@ class MaskedLinearOperator(LinearOperator): def __init__( self, - base: Float[LinearOperator, "*batch M0 N0"], - row_mask: Bool[Tensor, "M0"], - col_mask: Bool[Tensor, "N0"], + base: LinearOperator, # shape: (*batch, M0, N0) + row_mask: Tensor, # shape: (M0) + col_mask: Tensor, # shape: (N0) ): r""" Create a new :obj:`~linear_operator.operators.MaskedLinearOperator` that applies a mask to the rows and columns @@ -36,7 +35,10 @@ def __init__( self.row_eq_col_mask = torch.equal(row_mask, col_mask) @staticmethod - def _expand(tensor: Float[Tensor, "*batch N C"], mask: Bool[Tensor, "N0"]) -> Float[Tensor, "*batch N0 C"]: + def _expand( + tensor: Tensor, # shape: (*batch, N, C) + mask: Tensor, # shape: (N0) + ) -> Tensor: # shape: (*batch, N0, C) res = torch.zeros( *tensor.shape[:-2], mask.size(-1), @@ -48,9 +50,9 @@ def _expand(tensor: Float[Tensor, "*batch N C"], mask: Bool[Tensor, "N0"]) -> Fl return res def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) rhs_expanded = self._expand(rhs, self.col_mask) res_expanded = self.base._matmul(rhs_expanded) res = res_expanded[..., self.row_mask, :] @@ -58,9 +60,9 @@ def _matmul( return res def _t_matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]], - ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) rhs_expanded = self._expand(rhs, self.row_mask) res_expanded = self.base._t_matmul(rhs_expanded) res = res_expanded[..., self.col_mask, :] @@ -71,16 +73,22 @@ def _size(self) -> torch.Size: (*self.base.size()[:-2], torch.count_nonzero(self.row_mask), torch.count_nonzero(self.col_mask)) ) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return self.__class__(self.base.mT, self.col_mask, self.row_mask) - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... 
N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) if not self.row_eq_col_mask: raise NotImplementedError() diag = self.base.diagonal() return diag[..., self.row_mask] - def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + def to_dense( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch, M, N) full_dense = self.base.to_dense() return full_dense[..., self.row_mask, :][..., :, self.col_mask] @@ -90,8 +98,8 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O return self.base._bilinear_derivative(left_vecs, right_vecs) + (None, None) def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) return self.__class__(self.base._expand_batch(batch_shape), self.row_mask, self.col_mask) def _unsqueeze_batch(self, dim: int) -> LinearOperator: @@ -114,7 +122,11 @@ def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indice def _permute_batch(self, *dims: int) -> LinearOperator: return self.__class__(self.base._permute_batch(*dims), self.row_mask, self.col_mask) - def to(self: Float[LinearOperator, "*batch M N"], *args, **kwargs) -> Float[LinearOperator, "*batch M N"]: + def to( + self: LinearOperator, # shape: (*batch, M, N) + *args, + **kwargs, + ) -> LinearOperator: # shape: (*batch, M, N) # Overwrite the to() method in _linear_operator to avoid converting mask matrices to float. # Will only convert both dtype and device when arg's dtype is not torch.bool. diff --git a/linear_operator/operators/matmul_linear_operator.py b/linear_operator/operators/matmul_linear_operator.py index 0bd93f21..40c389ca 100644 --- a/linear_operator/operators/matmul_linear_operator.py +++ b/linear_operator/operators/matmul_linear_operator.py @@ -3,7 +3,6 @@ from typing import List, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator @@ -47,8 +46,8 @@ def __init__(self, left_linear_op, right_linear_op): self.right_linear_op = right_linear_op def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) return self.__class__( self.left_linear_op._expand_batch(batch_shape), self.right_linear_op._expand_batch(batch_shape) ) @@ -69,7 +68,9 @@ def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indice res = (left_tensor * right_tensor).sum(-1) return res - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) if isinstance(self.left_linear_op, DenseLinearOperator) and isinstance( self.right_linear_op, DenseLinearOperator ): @@ -95,15 +96,15 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I return res def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... 
M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) return self.left_linear_op._matmul(self.right_linear_op._matmul(rhs)) def _t_matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]], - ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) return self.right_linear_op._t_matmul(self.left_linear_op._t_matmul(rhs)) def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: @@ -125,9 +126,13 @@ def _permute_batch(self, *dims: int) -> LinearOperator: def _size(self) -> torch.Size: return _matmul_broadcast_shape(self.left_linear_op.shape, self.right_linear_op.shape) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return self.__class__(self.right_linear_op._transpose_nonbatch(), self.left_linear_op._transpose_nonbatch()) @cached - def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + def to_dense( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch, M, N) return torch.matmul(self.left_linear_op.to_dense(), self.right_linear_op.to_dense()) diff --git a/linear_operator/operators/mul_linear_operator.py b/linear_operator/operators/mul_linear_operator.py index 9d3dfd60..5371dbf7 100644 --- a/linear_operator/operators/mul_linear_operator.py +++ b/linear_operator/operators/mul_linear_operator.py @@ -2,7 +2,6 @@ from typing import List, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator @@ -38,7 +37,9 @@ def __init__(self, left_linear_op, right_linear_op): self.left_linear_op = left_linear_op self.right_linear_op = right_linear_op - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) res = self.left_linear_op._diagonal() * self.right_linear_op._diagonal() return res @@ -48,9 +49,9 @@ def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indice return left_res * right_res def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... 
M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) output_shape = _matmul_broadcast_shape(self.shape, rhs.shape) output_batch_shape = output_shape[:-2] @@ -79,8 +80,8 @@ def _matmul( return res def _mul_constant( - self: Float[LinearOperator, "*batch M N"], other: Union[float, torch.Tensor] - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) if other > 0: res = self.__class__(self.left_linear_op._mul_constant(other), self.right_linear_op) else: @@ -129,20 +130,24 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O return tuple(list(left_deriv_args) + list(right_deriv_args)) def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) return self.__class__( self.left_linear_op._expand_batch(batch_shape), self.right_linear_op._expand_batch(batch_shape) ) @cached - def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + def to_dense( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch, M, N) return self.left_linear_op.to_dense() * self.right_linear_op.to_dense() def _size(self) -> torch.Size: return self.left_linear_op.size() - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) # mul.linear_op only works with symmetric matrices return self diff --git a/linear_operator/operators/permutation_linear_operator.py b/linear_operator/operators/permutation_linear_operator.py index b0bdc3b7..c59e3765 100644 --- a/linear_operator/operators/permutation_linear_operator.py +++ b/linear_operator/operators/permutation_linear_operator.py @@ -1,7 +1,6 @@ from typing import Callable, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import LinearOperator @@ -13,19 +12,21 @@ class AbstractPermutationLinearOperator(LinearOperator): 3) the fact that permutation matrices' transposes are their inverses. """ - def inverse(self: Float[LinearOperator, "*batch N N"]) -> Float[LinearOperator, "*batch N N"]: + def inverse( + self: LinearOperator, # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) return self._transpose_nonbatch() def _solve( - self: Float[LinearOperator, "... N N"], - rhs: Float[torch.Tensor, "... N C"], - preconditioner: Optional[Callable[[Float[torch.Tensor, "... N C"]], Float[torch.Tensor, "... N C"]]] = None, + self: LinearOperator, # shape: (..., N, N) + rhs: torch.Tensor, # shape: (..., N, C) + preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) num_tridiag: Optional[int] = 0, ) -> Union[ - Float[torch.Tensor, "... N C"], + torch.Tensor, # shape: (..., N, C) Tuple[ - Float[torch.Tensor, "... 
N C"], - Float[torch.Tensor, "..."], # Note that in case of a tuple the second term size depends on num_tridiag + torch.Tensor, # shape: (..., N, C) + torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) ], ]: self._matmul_check_shape(rhs) @@ -96,9 +97,9 @@ def __init__( super().__init__(perm, inv_perm, validate_args=validate_args) def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) # input rhs is guaranteed to be at least two-dimensional due to matmul implementation self._matmul_check_shape(rhs) @@ -131,7 +132,9 @@ def _batch_indexing_helper(self, batch_shape: torch.Size) -> Tuple: def _size(self) -> torch.Size: return torch.Size((*self.perm.shape, self.perm.shape[-1])) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return PermutationLinearOperator(perm=self.inv_perm, inv_perm=self.perm, validate_args=False) def to_sparse(self) -> Tensor: @@ -167,16 +170,18 @@ def __init__(self, m: int): self._dtype = torch.float32 def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) self._matmul_check_shape(rhs) return rhs.unflatten(dim=-2, sizes=(self.m, self.m)).transpose(-3, -2).flatten(start_dim=-3, end_dim=-2) def _size(self) -> torch.Size: return torch.Size((self.n, self.n)) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return self @property diff --git a/linear_operator/operators/psd_sum_linear_operator.py b/linear_operator/operators/psd_sum_linear_operator.py index c35ac090..4535e75d 100644 --- a/linear_operator/operators/psd_sum_linear_operator.py +++ b/linear_operator/operators/psd_sum_linear_operator.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import LinearOperator @@ -12,6 +11,6 @@ class PsdSumLinearOperator(SumLinearOperator): """ def zero_mean_mvn_samples( - self: Float[LinearOperator, "*batch N N"], num_samples: int - ) -> Float[Tensor, "num_samples *batch N"]: + self: LinearOperator, num_samples: int # shape: (*batch, N, N) + ) -> Tensor: # shape: (num_samples, *batch, N) return sum(linear_op.zero_mean_mvn_samples(num_samples) for linear_op in self.linear_ops) diff --git a/linear_operator/operators/root_linear_operator.py b/linear_operator/operators/root_linear_operator.py index 50a257ef..bc029224 100644 --- a/linear_operator/operators/root_linear_operator.py +++ b/linear_operator/operators/root_linear_operator.py @@ -2,7 +2,6 @@ from typing import List, Optional, Union import torch -from jaxtyping import Float from torch import 
Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator @@ -20,15 +19,17 @@ def __init__(self, root): super().__init__(root) self.root = root - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) if isinstance(self.root, DenseLinearOperator): return (self.root.tensor**2).sum(-1) else: return super()._diagonal() def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) if len(batch_shape) == 0: return self return self.__class__(self.root._expand_batch(batch_shape)) @@ -65,14 +66,14 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I return res def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) return self.root._matmul(self.root._t_matmul(rhs)) def _mul_constant( - self: Float[LinearOperator, "*batch M N"], other: Union[float, torch.Tensor] - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) if (other > 0).all(): res = self.__class__(self.root._mul_constant(other.sqrt())) else: @@ -80,30 +81,30 @@ def _mul_constant( return res def _t_matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]], - ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) # Matrix is symmetric return self._matmul(rhs) def add_low_rank( - self: Float[LinearOperator, "*batch N N"], - low_rank_mat: Union[Float[Tensor, "... N _"], Float[LinearOperator, "... N _"]], + self: LinearOperator, # shape: (*batch, N, N) + low_rank_mat: Union[Tensor, LinearOperator], # shape: (..., N, _) root_decomp_method: Optional[str] = None, root_inv_decomp_method: Optional[str] = None, generate_roots: Optional[bool] = True, **root_decomp_kwargs, - ) -> Float[LinearOperator, "*batch N N"]: + ) -> LinearOperator: # shape: (*batch, N, N) return super().add_low_rank(low_rank_mat, root_inv_decomp_method=root_inv_decomp_method) def root_decomposition( - self: Float[LinearOperator, "*batch N N"], method: Optional[str] = None - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, method: Optional[str] = None # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) return self def _root_decomposition( - self: Float[LinearOperator, "... N N"] - ) -> Union[Float[torch.Tensor, "... N N"], Float[LinearOperator, "... 
N N"]]: + self: LinearOperator, # shape: (..., N, N) + ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) return self.root def _root_decomposition_size(self) -> int: @@ -112,10 +113,14 @@ def _root_decomposition_size(self) -> int: def _size(self) -> torch.Size: return torch.Size((*self.root.batch_shape, self.root.size(-2), self.root.size(-2))) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return self @cached - def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + def to_dense( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch, M, N) eval_root = self.root.to_dense() return torch.matmul(eval_root, eval_root.mT) diff --git a/linear_operator/operators/sum_batch_linear_operator.py b/linear_operator/operators/sum_batch_linear_operator.py index 50042a58..feaf85d4 100644 --- a/linear_operator/operators/sum_batch_linear_operator.py +++ b/linear_operator/operators/sum_batch_linear_operator.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator @@ -34,7 +33,9 @@ def _add_batch_dim(self, other): other = other.reshape(*shape).expand(*expand_shape) return other - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) diag = self.base_linear_op._diagonal().sum(-2) return diag @@ -61,5 +62,7 @@ def _size(self) -> torch.Size: del shape[-3] return torch.Size(shape) - def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + def to_dense( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch, M, N) return self.base_linear_op.to_dense().sum(dim=-3) # BlockLinearOperators always use dim3 for the block_dim diff --git a/linear_operator/operators/sum_kronecker_linear_operator.py b/linear_operator/operators/sum_kronecker_linear_operator.py index c2930ffe..68e86f9e 100644 --- a/linear_operator/operators/sum_kronecker_linear_operator.py +++ b/linear_operator/operators/sum_kronecker_linear_operator.py @@ -2,7 +2,6 @@ from typing import Callable, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import LinearOperator @@ -39,15 +38,15 @@ def _sum_formulation(self): return inv_root_times_lt1 def _solve( - self: Float[LinearOperator, "... N N"], - rhs: Float[torch.Tensor, "... N C"], - preconditioner: Optional[Callable[[Float[torch.Tensor, "... N C"]], Float[torch.Tensor, "... N C"]]] = None, + self: LinearOperator, # shape: (..., N, N) + rhs: torch.Tensor, # shape: (..., N, C) + preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) num_tridiag: Optional[int] = 0, ) -> Union[ - Float[torch.Tensor, "... N C"], + torch.Tensor, # shape: (..., N, C) Tuple[ - Float[torch.Tensor, "... N C"], - Float[torch.Tensor, "..."], # Note that in case of a tuple the second term size depends on num_tridiag + torch.Tensor, # shape: (..., N, C) + torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) 
], ]: inner_mat = self._sum_formulation @@ -64,14 +63,16 @@ def _solve( return res - def _logdet(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, " *batch"]: + def _logdet( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch) inner_mat = self._sum_formulation lt2_logdet = self.linear_ops[1].logdet() return inner_mat._logdet() + lt2_logdet def _root_decomposition( - self: Float[LinearOperator, "... N N"] - ) -> Union[Float[torch.Tensor, "... N N"], Float[LinearOperator, "... N N"]]: + self: LinearOperator, # shape: (..., N, N) + ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) inner_mat = self._sum_formulation lt2_root = KroneckerProductLinearOperator( *[lt.root_decomposition().root for lt in self.linear_ops[1].linear_ops] @@ -81,10 +82,10 @@ def _root_decomposition( return root def _root_inv_decomposition( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) initial_vectors: Optional[torch.Tensor] = None, test_vectors: Optional[torch.Tensor] = None, - ) -> Union[Float[LinearOperator, "... N N"], Float[Tensor, "... N N"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) inner_mat = self._sum_formulation lt2_root_inv = self.linear_ops[1].root_inv_decomposition().root inner_mat_root_inv = inner_mat.root_inv_decomposition().root @@ -92,13 +93,13 @@ def _root_inv_decomposition( return inv_root def inv_quad_logdet( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Optional[Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]]] = None, + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, ) -> Tuple[ - Optional[Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"], Float[Tensor, " 0"]]], - Optional[Float[Tensor, "..."]], + Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) + Optional[Tensor], # shape: (...) ]: inv_quad_term = None logdet_term = None diff --git a/linear_operator/operators/sum_linear_operator.py b/linear_operator/operators/sum_linear_operator.py index 3fb511b1..03fa196b 100644 --- a/linear_operator/operators/sum_linear_operator.py +++ b/linear_operator/operators/sum_linear_operator.py @@ -2,7 +2,6 @@ from typing import List, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator @@ -26,12 +25,14 @@ def __init__(self, *linear_ops, **kwargs): self.linear_ops = linear_ops - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops) def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... 
M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) expanded_tensors = [linear_op._expand_batch(batch_shape) for linear_op in self.linear_ops] return self.__class__(*expanded_tensors) @@ -44,14 +45,14 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I return SumLinearOperator(*results) def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) return sum(linear_op._matmul(rhs) for linear_op in self.linear_ops) def _mul_constant( - self: Float[LinearOperator, "*batch M N"], other: Union[float, torch.Tensor] - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) # We're using a custom method here - the constant mul is applied to the base_linear_ops return self.__class__(*[lt._mul_constant(other) for lt in self.linear_ops]) @@ -67,23 +68,27 @@ def _sum_batch(self, dim: int) -> LinearOperator: return self.__class__(*(linear_op._sum_batch(dim) for linear_op in self.linear_ops)) def _t_matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]], - ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) return sum(linear_op._t_matmul(rhs) for linear_op in self.linear_ops) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) linear_ops_t = [linear_op.mT for linear_op in self.linear_ops] return self.__class__(*linear_ops_t) @cached - def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + def to_dense( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch, M, N) return (sum(linear_op.to_dense() for linear_op in self.linear_ops)).contiguous() def __add__( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], - ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... 
M N"]]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) + ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator from linear_operator.operators.diag_linear_operator import DiagLinearOperator diff --git a/linear_operator/operators/toeplitz_linear_operator.py b/linear_operator/operators/toeplitz_linear_operator.py index 90057e2e..128873e8 100644 --- a/linear_operator/operators/toeplitz_linear_operator.py +++ b/linear_operator/operators/toeplitz_linear_operator.py @@ -2,7 +2,6 @@ from typing import List, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator @@ -23,15 +22,17 @@ def __init__(self, column): super(ToeplitzLinearOperator, self).__init__(column) self.column = column - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) diag_term = self.column[..., 0] if self.column.ndimension() > 1: diag_term = diag_term.unsqueeze(-1) return diag_term.expand(*self.column.size()) def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) return self.__class__(self.column.expand(*batch_shape, self.column.size(-1))) def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor: @@ -39,15 +40,15 @@ def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indice return self.column[(*batch_indices, toeplitz_indices)] def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) return sym_toeplitz_matmul(self.column, rhs) def _t_matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]], - ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... 
N P"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) # Matrix is symmetric return self._matmul(rhs) @@ -67,12 +68,14 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O def _size(self) -> torch.Size: return torch.Size((*self.column.shape, self.column.size(-1))) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return ToeplitzLinearOperator(self.column) def add_jitter( - self: Float[LinearOperator, "*batch N N"], jitter_val: float = 1e-3 - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, jitter_val: float = 1e-3 # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) jitter = torch.zeros_like(self.column) jitter.narrow(-1, 0, 1).fill_(jitter_val) return ToeplitzLinearOperator(self.column.add(jitter)) diff --git a/linear_operator/operators/triangular_linear_operator.py b/linear_operator/operators/triangular_linear_operator.py index 3dbeb883..2fb709eb 100644 --- a/linear_operator/operators/triangular_linear_operator.py +++ b/linear_operator/operators/triangular_linear_operator.py @@ -3,7 +3,6 @@ from typing import Callable, List, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator @@ -51,9 +50,9 @@ def __init__(self, tensor: Allsor, upper: bool = False) -> None: self._tensor = tensor def __add__( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], - ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) + ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) from linear_operator.operators.diag_linear_operator import DiagLinearOperator if isinstance(other, DiagLinearOperator): @@ -65,15 +64,15 @@ def __add__( return self._tensor + other def _cholesky( - self: Float[LinearOperator, "*batch N N"], upper: Optional[bool] = False - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) raise NotPSDError("TriangularLinearOperator does not allow a Cholesky decomposition") def _cholesky_solve( - self: Float[LinearOperator, "*batch N N"], - rhs: Union[Float[LinearOperator, "*batch2 N M"], Float[Tensor, "*batch2 N M"]], + self: LinearOperator, # shape: (*batch, N, N) + rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) upper: Optional[bool] = False, - ) -> Union[Float[LinearOperator, "... N M"], Float[Tensor, "... N M"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) # use custom method if implemented try: res = self._tensor._cholesky_solve(rhs=rhs, upper=upper) @@ -88,12 +87,14 @@ def _cholesky_solve( res = self._transpose_nonbatch().solve(w) return res - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) return self._tensor._diagonal() def _expand_batch( - self: Float[LinearOperator, "... 
M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) if len(batch_shape) == 0: return self return self.__class__(tensor=self._tensor._expand_batch(batch_shape), upper=self.upper) @@ -102,41 +103,41 @@ def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indice return self._tensor._get_indices(row_index, col_index, *batch_indices) def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) return self._tensor.matmul(rhs) def _mul_constant( - self: Float[LinearOperator, "*batch M N"], other: Union[float, torch.Tensor] - ) -> Float[LinearOperator, "*batch M N"]: + self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) return self.__class__(self._tensor * other.unsqueeze(-1), upper=self.upper) def _root_decomposition( - self: Float[LinearOperator, "... N N"] - ) -> Union[Float[torch.Tensor, "... N N"], Float[LinearOperator, "... N N"]]: + self: LinearOperator, # shape: (..., N, N) + ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) raise NotPSDError("TriangularLinearOperator does not allow a root decomposition") def _root_inv_decomposition( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) initial_vectors: Optional[torch.Tensor] = None, test_vectors: Optional[torch.Tensor] = None, - ) -> Union[Float[LinearOperator, "... N N"], Float[Tensor, "... N N"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) raise NotPSDError("TriangularLinearOperator does not allow an inverse root decomposition") def _size(self) -> torch.Size: return self._tensor.shape def _solve( - self: Float[LinearOperator, "... N N"], - rhs: Float[torch.Tensor, "... N C"], - preconditioner: Optional[Callable[[Float[torch.Tensor, "... N C"]], Float[torch.Tensor, "... N C"]]] = None, + self: LinearOperator, # shape: (..., N, N) + rhs: torch.Tensor, # shape: (..., N, C) + preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) num_tridiag: Optional[int] = 0, ) -> Union[ - Float[torch.Tensor, "... N C"], + torch.Tensor, # shape: (..., N, C) Tuple[ - Float[torch.Tensor, "... N C"], - Float[torch.Tensor, "..."], # Note that in case of a tuple the second term size depends on num_tridiag + torch.Tensor, # shape: (..., N, C) + torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) 
], ]: # already triangular, can just call solve for the solve @@ -145,7 +146,9 @@ def _solve( def _sum_batch(self, dim: int) -> LinearOperator: return self.__class__(self._tensor._sum_batch(dim), upper=self.upper) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return self.__class__(self._tensor._transpose_nonbatch(), upper=not self.upper) def abs(self) -> LinearOperator: @@ -155,29 +158,33 @@ def abs(self) -> LinearOperator: return self.__class__(self._tensor.abs(), upper=self.upper) def add_diagonal( - self: Float[LinearOperator, "*batch N N"], - diag: Union[Float[torch.Tensor, "... N"], Float[torch.Tensor, "... 1"], Float[torch.Tensor, ""]], - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, # shape: (*batch, N, N) + diag: torch.Tensor, # shape: (..., N) or (..., 1) or () + ) -> LinearOperator: # shape: (*batch, N, N) added_diag_lt = self._tensor.add_diagonal(diag) return self.__class__(added_diag_lt, upper=self.upper) - def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + def to_dense( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch, M, N) return self._tensor.to_dense() - def exp(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: + def exp( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, M, N) """ Returns a TriangleLinearOperator with all diagonal entries exponentiated. """ return self.__class__(self._tensor.exp(), upper=self.upper) def inv_quad_logdet( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Optional[Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]]] = None, + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, ) -> Tuple[ - Optional[Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"], Float[Tensor, " 0"]]], - Optional[Float[Tensor, "..."]], + Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) + Optional[Tensor], # shape: (...) ]: if inv_quad_rhs is None: inv_quad_term = torch.empty(0, dtype=self.dtype, device=self.device) @@ -196,7 +203,9 @@ def inv_quad_logdet( return inv_quad_term, logdet_term @cached - def inverse(self: Float[LinearOperator, "*batch N N"]) -> Float[LinearOperator, "*batch N N"]: + def inverse( + self: LinearOperator, # shape: (*batch, N, N) + ) -> LinearOperator: # shape: (*batch, N, N) """ Returns the inverse of the DiagLinearOperator. """ @@ -205,10 +214,10 @@ def inverse(self: Float[LinearOperator, "*batch N N"]) -> Float[LinearOperator, return self.__class__(inv, upper=self.upper) def solve( - self: Float[LinearOperator, "... N N"], - right_tensor: Union[Float[Tensor, "... N P"], Float[Tensor, " N"]], - left_tensor: Optional[Float[Tensor, "... O N"]] = None, - ) -> Union[Float[Tensor, "... N P"], Float[Tensor, "... N"], Float[Tensor, "... O P"], Float[Tensor, "... 
O"]]: + self: LinearOperator, # shape: (..., N, N) + right_tensor: Tensor, # shape: (..., N, P) or (N) + left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) squeeze = False if right_tensor.dim() == 1: right_tensor = right_tensor.unsqueeze(-1) diff --git a/linear_operator/operators/zero_linear_operator.py b/linear_operator/operators/zero_linear_operator.py index 8c6d5867..94481c3b 100644 --- a/linear_operator/operators/zero_linear_operator.py +++ b/linear_operator/operators/zero_linear_operator.py @@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union import torch -from jaxtyping import Float from torch import Tensor from linear_operator.operators._linear_operator import IndexType, LinearOperator @@ -43,13 +42,15 @@ def device(self) -> Optional[torch.device]: def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: raise RuntimeError("Backwards through a ZeroLinearOperator is not possible") - def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + def _diagonal( + self: LinearOperator, # shape: (..., M, N) + ) -> torch.Tensor: # shape: (..., N) shape = self.shape return torch.zeros(shape[:-1], dtype=self.dtype, device=self.device) def _expand_batch( - self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + ) -> LinearOperator: # shape: (..., M, N) return self.__class__(*batch_shape, *self.sizes[-2:], dtype=self._dtype, device=self._device) def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor: @@ -61,9 +62,9 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I return ZeroLinearOperator(*new_size) def _matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], - ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: torch.Tensor, # shape: (*batch2, N, C) or (*batch2, N) + ) -> torch.Tensor: # shape: (..., M, C) or (..., M) rhs_size_ind = -2 if rhs.ndimension() > 1 else -1 if self.size(-1) != rhs.size(rhs_size_ind): raise RuntimeError("Size mismatch, self: {}, rhs: {}".format(self.size(), rhs.size())) @@ -82,18 +83,18 @@ def _prod_batch(self, dim: int) -> LinearOperator: return self.__class__(*sizes, dtype=self._dtype, device=self._device) def _root_decomposition( - self: Float[LinearOperator, "... N N"] - ) -> Union[Float[torch.Tensor, "... N N"], Float[LinearOperator, "... N N"]]: + self: LinearOperator, # shape: (..., N, N) + ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) raise RuntimeError("ZeroLinearOperators are not positive definite!") def _root_decomposition_size(self) -> int: raise RuntimeError("ZeroLinearOperators are not positive definite!") def _root_inv_decomposition( - self: Float[LinearOperator, "*batch N N"], + self: LinearOperator, # shape: (*batch, N, N) initial_vectors: Optional[torch.Tensor] = None, test_vectors: Optional[torch.Tensor] = None, - ) -> Union[Float[LinearOperator, "... N N"], Float[Tensor, "... 
N N"]]: + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) raise RuntimeError("ZeroLinearOperators are not positive definite!") def _size(self) -> torch.Size: @@ -105,9 +106,9 @@ def _sum_batch(self, dim: int) -> LinearOperator: return self.__class__(*sizes, dtype=self._dtype, device=self._device) def _t_matmul( - self: Float[LinearOperator, "*batch M N"], - rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]], - ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]: + self: LinearOperator, # shape: (*batch, M, N) + rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) + ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) rhs_size_ind = -2 if rhs.ndimension() > 1 else -1 if self.size(-2) != rhs.size(rhs_size_ind): raise RuntimeError("Size mismatch, self: {}, rhs: {}".format(self.size(), rhs.size())) @@ -120,7 +121,9 @@ def _t_matmul( output_shape = (*batch_shape, new_m, n) return torch.zeros(*output_shape, dtype=rhs.dtype, device=rhs.device) - def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + def _transpose_nonbatch( + self: LinearOperator, # shape: (*batch, M, N) + ) -> LinearOperator: # shape: (*batch, N, M) return self.mT def _unsqueeze_batch(self, dim: int) -> LinearOperator: @@ -129,9 +132,9 @@ def _unsqueeze_batch(self, dim: int) -> LinearOperator: return self.__class__(*sizes, dtype=self._dtype, device=self._device) def add_diagonal( - self: Float[LinearOperator, "*batch N N"], - diag: Union[Float[torch.Tensor, "... N"], Float[torch.Tensor, "... 1"], Float[torch.Tensor, ""]], - ) -> Float[LinearOperator, "*batch N N"]: + self: LinearOperator, # shape: (*batch, N, N) + diag: torch.Tensor, # shape: (..., N) or (..., 1) or () + ) -> LinearOperator: # shape: (*batch, N, N) from linear_operator.operators.diag_linear_operator import DiagLinearOperator if self.size(-1) != self.size(-2): @@ -172,30 +175,32 @@ def div(self, other: Union[float, torch.Tensor]) -> LinearOperator: return self def inv_quad( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]], + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Tensor, # shape: (*batch, N, M) or (*batch, N) reduce_inv_quad: bool = True, - ) -> Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"]]: + ) -> Tensor: # shape: (*batch, M) or (*batch) raise RuntimeError("ZeroLinearOperators are not invertible!") def inv_quad_logdet( - self: Float[LinearOperator, "*batch N N"], - inv_quad_rhs: Optional[Union[Float[Tensor, "*batch N M"], Float[Tensor, "*batch N"]]] = None, + self: LinearOperator, # shape: (*batch, N, N) + inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, ) -> Tuple[ - Optional[Union[Float[Tensor, "*batch M"], Float[Tensor, " *batch"], Float[Tensor, " 0"]]], - Optional[Float[Tensor, "..."]], + Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) + Optional[Tensor], # shape: (...) 
]: raise RuntimeError("ZeroLinearOperators are not invertible!") - def logdet(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, " *batch"]: + def logdet( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch) return torch.log(torch.tensor(0.0)) def matmul( - self: Float[LinearOperator, "*batch M N"], - other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]], - ) -> Union[Float[Tensor, "... M P"], Float[Tensor, "... M"], Float[LinearOperator, "... M P"]]: + self: LinearOperator, # shape: (*batch, M, N) + other: Union[Tensor, LinearOperator], # shape: (*batch2, N, P) or (*batch2, N) + ) -> Union[Tensor, LinearOperator]: # shape: (..., M, P) or (..., M) tensor_size_ind = -2 if other.ndimension() > 1 else -1 if self.size(-1) != other.size(tensor_size_ind): raise RuntimeError("Size mismatch, self: {}, other: {}".format(self.size(), other.size())) @@ -209,21 +214,23 @@ def matmul( return ZeroLinearOperator(*output_shape, dtype=other.dtype, device=other.device) def mul( - self: Float[LinearOperator, "*batch M N"], - other: Union[float, Float[Tensor, "*batch2 M N"], Float[LinearOperator, "*batch2 M N"]], - ) -> Float[LinearOperator, "... M N"]: + self: LinearOperator, # shape: (*batch, M, N) + other: Union[float, Tensor, LinearOperator], # shape: (*batch2, M, N) + ) -> LinearOperator: # shape: (..., M, N) shape = torch.broadcast_shapes(self.shape, other.shape) return self.__class__(*shape, dtype=self._dtype, device=self._device) def solve( - self: Float[LinearOperator, "... N N"], - right_tensor: Union[Float[Tensor, "... N P"], Float[Tensor, " N"]], - left_tensor: Optional[Float[Tensor, "... O N"]] = None, - ) -> Union[Float[Tensor, "... N P"], Float[Tensor, "... N"], Float[Tensor, "... O P"], Float[Tensor, "... O"]]: + self: LinearOperator, # shape: (..., N, N) + right_tensor: Tensor, # shape: (..., N, P) or (N) + left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) raise RuntimeError("ZeroLinearOperators are not invertible!") @cached - def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + def to_dense( + self: LinearOperator, # shape: (*batch, M, N) + ) -> Tensor: # shape: (*batch, M, N) return torch.zeros(*self.sizes) def transpose(self, dim1: int, dim2: int) -> LinearOperator: @@ -235,7 +242,7 @@ def transpose(self, dim1: int, dim2: int) -> LinearOperator: return ZeroLinearOperator(*sizes) def __add__( - self: Float[LinearOperator, "... #M #N"], - other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], - ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]: + self: LinearOperator, # shape: (..., #M, #N) + other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) + ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) return other diff --git a/linear_operator/settings.py b/linear_operator/settings.py index 14c9107b..98332e2e 100644 --- a/linear_operator/settings.py +++ b/linear_operator/settings.py @@ -331,6 +331,7 @@ class fast_computations: .. 
_GPyTorch Blackbox Matrix-Matrix Gaussian Process Inference with GPU Acceleration: https://arxiv.org/pdf/1809.11165.pdf """ + covar_root_decomposition = _fast_covar_root_decomposition log_prob = _fast_log_prob solves = _fast_solves diff --git a/linear_operator/test/type_checking_test_case.py b/linear_operator/test/type_checking_test_case.py deleted file mode 100644 index c66ac9f8..00000000 --- a/linear_operator/test/type_checking_test_case.py +++ /dev/null @@ -1,97 +0,0 @@ -# These are tests to directly check torchtyping signatures as extended to LinearOperator. -# The idea is to verify that dimension tests are working as expected. -import unittest -from typing import Union - -import torch -from jaxtyping import Float, jaxtyped - -# Use your favourite typechecker: usually one of the two lines below. -from typeguard import typechecked as typechecker - -from linear_operator.operators import DenseLinearOperator, LinearOperator - - -@jaxtyped -@typechecker -def linop_matmul_fn( - lo: Float[LinearOperator, "*batch M N"], - vec: Union[Float[torch.Tensor, "*batch N C"], Float[torch.Tensor, "*batch N"]], -) -> Union[Float[torch.Tensor, "*batch M C"], Float[torch.Tensor, "*batch M"]]: - r""" - Performs a matrix multiplication :math:`\mathbf KM` with the (... x M x N) matrix :math:`\mathbf K` - that lo represents. Should behave as - :func:`torch.matmul`. If the LinearOperator represents a batch of - matrices, this method should therefore operate in batch mode as well. - - :param lo: the K = MxN left hand matrix - :param vec: the matrix :math:`\mathbf M` to multiply with (... x N x C). - :return: :math:`\mathbf K \mathbf M` (... x M x C) - """ - res = lo.matmul(vec) - return res - - -@jaxtyped -@typechecker -def linop_matmul_fn_bad_lo( - lo: Float[LinearOperator, " N"], - vec: Union[Float[torch.Tensor, "*batch N C"], Float[torch.Tensor, "*batch N"]], -) -> Union[Float[torch.Tensor, "*batch M C"], Float[torch.Tensor, "*batch M"]]: - r""" - As above, but with bad size array for lo - """ - res = lo.matmul(vec) - return res - - -@jaxtyped -@typechecker -def linop_matmul_fn_bad_vec( - lo: Float[LinearOperator, "*batch M N"], vec: Float[torch.Tensor, "*batch N C"] -) -> Union[Float[torch.Tensor, "*batch M C"], Float[torch.Tensor, "*batch M"]]: - r""" - As above, but with bad size array for vec - """ - res = lo.matmul(vec) - return res - - -@jaxtyped -@typechecker -def linop_matmul_fn_bad_retn( - lo: Float[LinearOperator, "*batch M N"], - vec: Union[Float[torch.Tensor, "*batch N C"], Float[torch.Tensor, "*batch N"]], -) -> Float[torch.Tensor, "*batch M C"]: - r""" - As above, but with bad return size - """ - res = lo.matmul(vec) - return res - - -class TestTypeChecking(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - mat = torch.tensor([[3, -1, 0], [-1, 3, 0], [0, 0, 3]], dtype=torch.float) - self.vec = torch.randn(3) - self.lo = DenseLinearOperator(mat) - - def test_linop_matmul_fn(self): - linop_matmul_fn(self.lo, self.vec) - - def test_linop_matmul_fn_bad_lo(self): - with self.assertRaises(TypeError): - linop_matmul_fn_bad_lo(self.lo, self.vec) - - def test_linop_matmul_fn_bad_vec(self): - with self.assertRaises(TypeError): - linop_matmul_fn_bad_vec(self.lo, self.vec) - - def test_linop_matmul_fn_bad_retn(self): - with self.assertRaises(TypeError): - linop_matmul_fn_bad_retn(self.lo, self.vec) - - -if __name__ == "__main__": - unittest.main() diff --git a/setup.py b/setup.py index 8a55955d..9352cfdf 100644 --- a/setup.py +++ b/setup.py @@ -37,10 +37,7 @@ 
pass # Other requirements -install_requires += [ - "scipy", - "jaxtyping", -] +install_requires += ["scipy"] # Get version From 7a4372018019fc7bf0a55a1d1d44dd45cdf83cb0 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Mon, 2 Feb 2026 17:12:04 -0800 Subject: [PATCH 2/4] Also remove typeguard comments These aren't necessary --- .../test_kronecker_product_added_diag_linear_operator.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/test/operators/test_kronecker_product_added_diag_linear_operator.py b/test/operators/test_kronecker_product_added_diag_linear_operator.py index ad9a2d8d..b055088b 100644 --- a/test/operators/test_kronecker_product_added_diag_linear_operator.py +++ b/test/operators/test_kronecker_product_added_diag_linear_operator.py @@ -16,13 +16,6 @@ ) from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase -# requires typeguard 3.1+ -# import typeguard -# import pytest -# @pytest.fixture(autouse=True) -# def suppress_typeguard(): -# yield typeguard.suppress_type_checks() - class TestKroneckerProductAddedDiagLinearOperator(unittest.TestCase, LinearOperatorTestCase): # this lazy tensor has an explicit inverse so we don't need to run these From d54e35be48db52b35bd2b3fbc72d37c8925f2130 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Mon, 2 Feb 2026 17:15:04 -0800 Subject: [PATCH 3/4] Fix unused import lint --- linear_operator/operators/keops_linear_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linear_operator/operators/keops_linear_operator.py b/linear_operator/operators/keops_linear_operator.py index 76af001b..0f1981cf 100644 --- a/linear_operator/operators/keops_linear_operator.py +++ b/linear_operator/operators/keops_linear_operator.py @@ -2,7 +2,7 @@ import warnings -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import torch from torch import Tensor From 347f9fefc072f3488de13c95cbe8c372d3f83d9a Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Mon, 2 Feb 2026 17:55:07 -0800 Subject: [PATCH 4/4] Avoid reformatting of code blocks with shape comments for readability --- linear_operator/operators/_linear_operator.py | 4 ++-- .../operators/batch_repeat_linear_operator.py | 4 ++-- .../operators/block_diag_linear_operator.py | 4 ++-- .../operators/block_interleaved_linear_operator.py | 4 ++-- linear_operator/operators/cat_linear_operator.py | 4 ++-- linear_operator/operators/chol_linear_operator.py | 4 ++-- linear_operator/operators/diag_linear_operator.py | 4 ++-- .../operators/identity_linear_operator.py | 4 ++-- .../kronecker_product_added_diag_linear_operator.py | 4 ++-- .../operators/kronecker_product_linear_operator.py | 4 ++-- .../low_rank_root_added_diag_linear_operator.py | 4 ++-- .../operators/sum_kronecker_linear_operator.py | 4 ++-- .../operators/triangular_linear_operator.py | 12 ++++++------ linear_operator/operators/zero_linear_operator.py | 4 ++-- 14 files changed, 32 insertions(+), 32 deletions(-) diff --git a/linear_operator/operators/_linear_operator.py b/linear_operator/operators/_linear_operator.py index 3e301032..9e814637 100644 --- a/linear_operator/operators/_linear_operator.py +++ b/linear_operator/operators/_linear_operator.py @@ -1682,10 +1682,10 @@ def inv_quad_logdet( inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ + ) -> Tuple[ # fmt: off Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) Optional[Tensor], # shape: (...) 
- ]: + ]: # fmt: on r""" Calls both :func:`inv_quad` and :func:`logdet` on a positive definite matrix (or batch) :math:`\mathbf A`. However, calling this diff --git a/linear_operator/operators/batch_repeat_linear_operator.py b/linear_operator/operators/batch_repeat_linear_operator.py index 7188d32b..333e8f5c 100644 --- a/linear_operator/operators/batch_repeat_linear_operator.py +++ b/linear_operator/operators/batch_repeat_linear_operator.py @@ -260,10 +260,10 @@ def inv_quad_logdet( inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ + ) -> Tuple[ # fmt: off Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) Optional[Tensor], # shape: (...) - ]: + ]: # fmt: on if not self.is_square: raise RuntimeError( "inv_quad_logdet only operates on (batches of) square (positive semi-definite) LinearOperators. " diff --git a/linear_operator/operators/block_diag_linear_operator.py b/linear_operator/operators/block_diag_linear_operator.py index e5b97d0d..a523abc8 100644 --- a/linear_operator/operators/block_diag_linear_operator.py +++ b/linear_operator/operators/block_diag_linear_operator.py @@ -165,10 +165,10 @@ def inv_quad_logdet( inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ + ) -> Tuple[ # fmt: off Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) Optional[Tensor], # shape: (...) - ]: + ]: # fmt: on if inv_quad_rhs is not None: inv_quad_rhs = self._add_batch_dim(inv_quad_rhs) inv_quad_res, logdet_res = self.base_linear_op.inv_quad_logdet( diff --git a/linear_operator/operators/block_interleaved_linear_operator.py b/linear_operator/operators/block_interleaved_linear_operator.py index ba6f7d39..6b077936 100644 --- a/linear_operator/operators/block_interleaved_linear_operator.py +++ b/linear_operator/operators/block_interleaved_linear_operator.py @@ -130,10 +130,10 @@ def inv_quad_logdet( inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ + ) -> Tuple[ # fmt: off Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) Optional[Tensor], # shape: (...) - ]: + ]: # fmt: on if inv_quad_rhs is not None: inv_quad_rhs = self._add_batch_dim(inv_quad_rhs) inv_quad_res, logdet_res = self.base_linear_op.inv_quad_logdet( diff --git a/linear_operator/operators/cat_linear_operator.py b/linear_operator/operators/cat_linear_operator.py index 2e8974a4..b16574a6 100644 --- a/linear_operator/operators/cat_linear_operator.py +++ b/linear_operator/operators/cat_linear_operator.py @@ -393,10 +393,10 @@ def inv_quad_logdet( inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ + ) -> Tuple[ # fmt: off Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) Optional[Tensor], # shape: (...) 
- ]: + ]: # fmt: on res = super().inv_quad_logdet(inv_quad_rhs, logdet, reduce_inv_quad) return tuple(r.to(self.device) for r in res) diff --git a/linear_operator/operators/chol_linear_operator.py b/linear_operator/operators/chol_linear_operator.py index 3be5dd92..2983b9b0 100644 --- a/linear_operator/operators/chol_linear_operator.py +++ b/linear_operator/operators/chol_linear_operator.py @@ -123,10 +123,10 @@ def inv_quad_logdet( inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ + ) -> Tuple[ # fmt: off Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) Optional[Tensor], # shape: (...) - ]: + ]: # fmt: on if not self.is_square: raise RuntimeError( "inv_quad_logdet only operates on (batches of) square (positive semi-definite) LinearOperators. " diff --git a/linear_operator/operators/diag_linear_operator.py b/linear_operator/operators/diag_linear_operator.py index 03c485f0..b4d56789 100644 --- a/linear_operator/operators/diag_linear_operator.py +++ b/linear_operator/operators/diag_linear_operator.py @@ -165,10 +165,10 @@ def inv_quad_logdet( inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ + ) -> Tuple[ # fmt: off Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) Optional[Tensor], # shape: (...) - ]: + ]: # fmt: on # TODO: Use proper batching for inv_quad_rhs (prepand to shape rathern than append) if inv_quad_rhs is None: rhs_batch_shape = torch.Size() diff --git a/linear_operator/operators/identity_linear_operator.py b/linear_operator/operators/identity_linear_operator.py index 1d9b75c9..caee46cf 100644 --- a/linear_operator/operators/identity_linear_operator.py +++ b/linear_operator/operators/identity_linear_operator.py @@ -189,10 +189,10 @@ def inv_quad_logdet( inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ + ) -> Tuple[ # fmt: off Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) Optional[Tensor], # shape: (...) - ]: + ]: # fmt: on # TODO: Use proper batching for inv_quad_rhs (prepand to shape rather than append) if inv_quad_rhs is None: inv_quad_term = torch.empty(0, dtype=self.dtype, device=self.device) diff --git a/linear_operator/operators/kronecker_product_added_diag_linear_operator.py b/linear_operator/operators/kronecker_product_added_diag_linear_operator.py index 5e703c58..d362967c 100644 --- a/linear_operator/operators/kronecker_product_added_diag_linear_operator.py +++ b/linear_operator/operators/kronecker_product_added_diag_linear_operator.py @@ -67,10 +67,10 @@ def inv_quad_logdet( inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ + ) -> Tuple[ # fmt: off Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) Optional[Tensor], # shape: (...) 
- ]: + ]: # fmt: on if inv_quad_rhs is not None: inv_quad_term, _ = super().inv_quad_logdet( inv_quad_rhs=inv_quad_rhs, logdet=False, reduce_inv_quad=reduce_inv_quad diff --git a/linear_operator/operators/kronecker_product_linear_operator.py b/linear_operator/operators/kronecker_product_linear_operator.py index 3a98c4a5..f76fe37e 100644 --- a/linear_operator/operators/kronecker_product_linear_operator.py +++ b/linear_operator/operators/kronecker_product_linear_operator.py @@ -164,10 +164,10 @@ def inv_quad_logdet( inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ + ) -> Tuple[ # fmt: off Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) Optional[Tensor], # shape: (...) - ]: + ]: # fmt: on if inv_quad_rhs is not None: inv_quad_term, _ = super().inv_quad_logdet( inv_quad_rhs=inv_quad_rhs, logdet=False, reduce_inv_quad=reduce_inv_quad diff --git a/linear_operator/operators/low_rank_root_added_diag_linear_operator.py b/linear_operator/operators/low_rank_root_added_diag_linear_operator.py index 705552f3..541d9493 100644 --- a/linear_operator/operators/low_rank_root_added_diag_linear_operator.py +++ b/linear_operator/operators/low_rank_root_added_diag_linear_operator.py @@ -112,10 +112,10 @@ def inv_quad_logdet( inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ + ) -> Tuple[ # fmt: off Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) Optional[Tensor], # shape: (...) - ]: + ]: # fmt: on if not self.is_square: raise RuntimeError( "inv_quad_logdet only operates on (batches of) square (positive semi-definite) LinearOperators. " diff --git a/linear_operator/operators/sum_kronecker_linear_operator.py b/linear_operator/operators/sum_kronecker_linear_operator.py index 68e86f9e..760a08aa 100644 --- a/linear_operator/operators/sum_kronecker_linear_operator.py +++ b/linear_operator/operators/sum_kronecker_linear_operator.py @@ -97,10 +97,10 @@ def inv_quad_logdet( inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ + ) -> Tuple[ # fmt: off Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) Optional[Tensor], # shape: (...) - ]: + ]: # fmt: on inv_quad_term = None logdet_term = None diff --git a/linear_operator/operators/triangular_linear_operator.py b/linear_operator/operators/triangular_linear_operator.py index 2fb709eb..201b2700 100644 --- a/linear_operator/operators/triangular_linear_operator.py +++ b/linear_operator/operators/triangular_linear_operator.py @@ -182,10 +182,10 @@ def inv_quad_logdet( inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ + ) -> Tuple[ # fmt: off Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) Optional[Tensor], # shape: (...) - ]: + ]: # fmt: on if inv_quad_rhs is None: inv_quad_term = torch.empty(0, dtype=self.dtype, device=self.device) else: @@ -245,15 +245,15 @@ def solve_triangular( ) -> torch.Tensor: if upper != self.upper: raise RuntimeError( - f"Incompatible argument: {self.__class__.__name__}.solve_triangular called with `upper={upper}`, " - f"but `LinearOperator` has `upper={self.upper}`." 
+ f"Incompatible argument: {self.__class__.__name__}.solve_triangular called with 'upper={upper}', " + f"but 'LinearOperator' has 'upper={self.upper}'." ) if not left: raise NotImplementedError( - f"Argument `left=False` not yet supported for {self.__class__.__name__}.solve_triangular." + f"Argument 'left=False' not yet supported for {self.__class__.__name__}.solve_triangular." ) if unitriangular: raise NotImplementedError( - f"Argument `unitriangular=True` not yet supported for {self.__class__.__name__}.solve_triangular." + f"Argument 'unitriangular=True' not yet supported for {self.__class__.__name__}.solve_triangular." ) return self.solve(right_tensor=rhs) diff --git a/linear_operator/operators/zero_linear_operator.py b/linear_operator/operators/zero_linear_operator.py index 94481c3b..f6665a67 100644 --- a/linear_operator/operators/zero_linear_operator.py +++ b/linear_operator/operators/zero_linear_operator.py @@ -186,10 +186,10 @@ def inv_quad_logdet( inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) logdet: Optional[bool] = False, reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ + ) -> Tuple[ # fmt: off Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) Optional[Tensor], # shape: (...) - ]: + ]: # fmt: on raise RuntimeError("ZeroLinearOperators are not invertible!") def logdet(