# Source code for tad_mctc.storch.linalg

# This file is part of tad-mctc.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SafeOps: Linear Algebra
=======================

A collection of common mathematical functions.

This module contains a collection of batch-operable, back-propagatable
mathematical functions.

Taken from TBMaLT.
https://github.com/tbmalt/tbmalt/blob/main/tbmalt/common/maths/__init__.py
"""

from __future__ import annotations

import numpy as np
import torch

from .._version import __tversion__
from ..convert import symmetrize
from ..typing import Any, Callable, Literal, Tensor

__all__ = ["eighb"]


def estimate_minmax(amat: Tensor) -> tuple[Tensor, Tensor]:
    """
    Bound the eigenvalue spectrum of a matrix via Gershgorin discs.

    Every row ``i`` defines a disc centred on the diagonal element
    ``amat[..., i, i]`` whose radius is the sum of the absolute values of the
    remaining elements of that row. The smallest left edge and the largest
    right edge over all discs are returned.

    Parameters
    ----------
    amat : Tensor
        Matrix (or batch of matrices in the last two dimensions) whose
        spectrum is to be bounded. Intended for symmetric matrices.

    Returns
    -------
    tuple of Tensor
        ``(lower, upper)``: the estimated minimum and maximum eigenvalue. The
        last two dimensions of ``amat`` are reduced away, so a batch of
        matrices yields one value per matrix.

    Examples
    --------
    >>> amat = torch.tensor([
    ...     [[-1.1258, -0.1794,  0.1126],
    ...      [-0.1794,  0.5988,  0.1490],
    ...      [ 0.1126,  0.1490,  0.4681]],
    ...     [[-0.1577,  0.6080, -0.3301],
    ...      [ 0.6080,  1.5863,  0.9391],
    ...      [-0.3301,  0.9391,  1.2590]],
    ... ])
    >>> estimate_minmax(amat)
    (tensor([-1.4178, -1.0958]), tensor([0.9272, 3.1334]))
    >>> evals = torch.linalg.eigh(amat)[0]
    >>> evals.min(-1)[0], evals.max(-1)[0]
    (tensor([-1.1543, -0.5760]), tensor([0.7007, 2.4032]))

    Notes
    -----
    The returned values are bounds obtained from the Gershgorin circle
    theorem; they enclose the spectrum but are generally not eigenvalues
    themselves.
    """
    # Disc centres: the diagonal of each matrix.
    diag = torch.diagonal(amat, dim1=-2, dim2=-1)

    # Disc radii: absolute row sums with the diagonal contribution removed.
    offdiag = amat.abs().sum(dim=-1) - diag.abs()

    lower = (diag - offdiag).min(dim=-1).values
    upper = (diag + offdiag).max(dim=-1).values
    return lower, upper


class SymEigBroadBase(torch.autograd.Function):
    r"""
    Solves standard eigenvalue problems for real symmetric matrices,
    suitable for solving multiple systems with batch processing where
    the first dimension iterates over instances of the batch.

    This function can apply conditional or Lorentzian broadening to the
    eigenvalues during the backwards pass to increase gradient stability.

    This base class only provides the shared `backward()`; the `forward()`
    (and, where applicable, `setup_context()`) is supplied by subclasses.

    Parameters
    ----------
    a : Tensor
        A real symmetric matrix whose eigenvalues & eigenvectors will be computed.
    method : {'cond', 'lorn'}, optional
        Broadening method to use. 'cond' refers to conditional broadening,
        'lorn' to Lorentzian broadening. Default is 'cond'.
    factor : Tensor | float, optional
        Degree of broadening (broadening factor). Default is 1E-12.

    Returns
    -------
    w : Tensor
        The eigenvalues, in ascending order.
    v : Tensor
        The eigenvectors.

    Notes
    -----
    Results from backward passes through eigen-decomposition operations
    tend to suffer from numerical stability issues, especially when operating
    on systems with degenerate eigenvalues. Applying eigenvalue broadening
    increases stability but introduces small errors in the gradients. The
    extent of broadening correlates with the stability improvement and the
    error magnitude.

    Two broadening methods are implemented: Conditional broadening as
    described by Seeger [MS2019]_, and Lorentzian as detailed by Liao [LH2019]_.
    The forward pass of the subclasses uses `torch.linalg.eigh` to calculate
    eigenvalues and eigenvectors. The gradient is calculated as:

    .. math:: \bar{A} = U (\bar{\Lambda} + sym(F \circ (U^T \bar{U}))) U^T

    where :math:`\bar{\Lambda}` is the diagonal matrix of the eigenvalue gradients,
    :math:`\circ` denotes the Hadamard product, and `sym` is the symmetrisation
    operator. F is defined as :math:`F_{i, j} = \frac{I_{i \ne j}}{h(\lambda_i - \lambda_j)}`
    with `h` being a function specific to the chosen broadening method.

    Conditional broadening applies only when necessary, limiting gradient errors
    to systems that would otherwise yield NaNs. Lorentzian broadening affects all
    systems regardless of necessity. Without broadening, the backward pass
    resembles a standard eigen-solver.

    References
    ----------
    .. [MS2019] Seeger, M., Hetzel, A., Dai, Z., & Meissner, E. Auto-
                Differentiating Linear Algebra. ArXiv:1710.08717 [Cs,
                Stat], Aug. 2019. arXiv.org, http://arxiv.org/abs/1710.08717.
    .. [LH2019] Liao, H.-J., Liu, J.-G., Wang, L., & Xiang, T. (2019).
                Differentiable Programming Tensor Networks. Physical
                Review X, 9(3).
    .. [Lapack] www.netlib.org/lapack/lug/node54.html (Accessed 21/04/2023)

    """

    # Methods accepted by the subclasses' `forward()`. Note that 'none' (and
    # its alias `None`) is included only for testing purposes: it disables
    # broadening in the backward pass.
    KNOWN_METHODS = ["cond", "lorn", "none", None]

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: Any,
        w_bar: Tensor,
        v_bar: Tensor,
    ) -> tuple[Tensor, None, None]:
        """
        Evaluates gradients of the eigen decomposition operation.

        This method evaluates the gradients of the matrix from which the
        eigenvalues and eigenvectors were originally computed during the
        forward pass.

        Parameters
        ----------
        ctx : Any
            Context object. It must provide the saved eigenvalues and
            eigenvectors (``ctx.saved_tensors``), the broadening factor
            (``ctx.bf``), the broadening method (``ctx.bm``) and the
            ``ctx.dtype``/``ctx.device`` of the decomposed matrix.
        w_bar : Tensor
            Gradients associated with the eigenvalues.
        v_bar : Tensor
            Gradients associated with the eigenvectors.

        Returns
        -------
        tuple[Tensor, None, None]
            A tuple containing the gradient of the input matrix and two None
            placeholders for method and factor, which do not require gradients.

        Raises
        ------
        ValueError
            If ``ctx.bm`` is not one of ``KNOWN_METHODS``.

        Notes
        -----
        This method should only be called by PyTorch's automatic differentiation
        mechanism. For a more detailed description of the gradient computation,
        refer to the class docstring.
        """
        # Equation to variable legend
        #   w <- λ
        #   v <- U

        # __Preamble__
        # Retrieve eigenvalues (w) and eigenvectors (v) from ctx
        w: Tensor = ctx.saved_tensors[0]
        v: Tensor = ctx.saved_tensors[1]

        # Retrieve the broadening factor and convert to a tensor entity
        if not isinstance(ctx.bf, Tensor):
            bf = torch.tensor(ctx.bf, dtype=ctx.dtype, device=ctx.device)
        else:
            bf = ctx.bf

        # Retrieve the broadening method
        bm = ctx.bm

        # Form the eigenvalue gradients into diagonal matrix
        lambda_bar = w_bar.diag_embed()

        # Identify the indices of the upper triangle of the F matrix. They
        # are created on the device of the eigenvalues so that the indexing
        # below does not mix devices.
        rows, cols = v.shape[-2:]
        tri_u = torch.triu_indices(rows, cols, offset=1, device=w.device)

        # Construct the deltas (λ_j - λ_i for j > i). These are non-negative
        # provided the eigenvalues are in ascending order, as returned by
        # `torch.linalg.eigh`.
        deltas = w[..., tri_u[1]] - w[..., tri_u[0]]

        # Apply broadening
        if bm == "cond":  # <- Conditional broadening
            # Gaps not larger than `bf` are replaced by `bf`; the sign factor
            # additionally zeroes the contribution of exactly degenerate pairs.
            deltas = (
                1
                / torch.where(torch.abs(deltas) > bf, deltas, bf)
                * torch.sign(deltas)
            )
        elif bm == "lorn":  # <- Lorentzian broadening
            deltas = deltas / (deltas**2 + bf)
        elif bm == "none" or bm is None:  # <- Debugging only
            # `None` is accepted by `KNOWN_METHODS`, so it is handled here in
            # the same way as "none" instead of failing in the backward pass.
            deltas = 1 / deltas
        else:  # pragma: no cover
            # Should be impossible to get here
            raise ValueError(f"Unknown broadening method {bm}")

        # Construct the antisymmetric F matrix with F_ij = 1 / h(λ_j - λ_i)
        # for i != j and zeros on the diagonal. Filling a zero matrix from the
        # triangle avoids evaluating 1/0 on the diagonal, which can cause
        # intermittent and hard-to-diagnose issues.
        F = torch.zeros(
            *w.shape, w.shape[-1], dtype=ctx.dtype, device=w_bar.device
        )
        # Upper then lower triangle
        F[..., tri_u[0], tri_u[1]] = deltas
        F[..., tri_u[1], tri_u[0]] -= F[..., tri_u[0], tri_u[1]]

        # Construct the gradient following the equation in the doc-string.
        temp = symmetrize(F * (v.transpose(-2, -1) @ v_bar), force=True)
        a_bar = v @ (lambda_bar + temp) @ v.transpose(-2, -1)

        # Return the gradient. PyTorch expects a gradient for each parameter
        # (method, bf) hence two extra Nones are returned
        return a_bar, None, None


class _SymEigBroad_V1(SymEigBroadBase):  # pragma: no cover
    """
    Symmetric eigen-solver with a broadened backward pass, written in the
    legacy custom-function style where a single `forward()` both performs the
    computation and fills the context object. This was the only way before
    PyTorch 2.0.0, but is still supported.

    The backward logic is inherited from the base class; see its docstring
    for details on the broadening methods.
    """

    @staticmethod
    def forward(  # type: ignore[override]
        ctx: Any,
        a: Tensor,
        method: str = "cond",
        factor: Tensor | float = 1e-12,
    ) -> tuple[Tensor, Tensor]:
        """
        Diagonalise a real symmetric matrix and prepare the backward pass.

        The decomposition itself is delegated to `torch.linalg.eigh`. The
        broadening method and factor are not applied here; they are only
        stored on `ctx` for use in `backward()`.

        Parameters
        ----------
        ctx : Any
            Context object supplied by PyTorch; used to hand data over to
            `backward()`.
        a : Tensor
            A real symmetric matrix whose eigenvalues and eigenvectors will
            be computed.
        method : {'cond', 'lorn'}, optional
            Broadening method to be used in the backward pass:
            - 'cond' for conditional broadening.
            - 'lorn' for Lorentzian broadening.
            The default is 'cond'. Must be a member of
            ``SymEigBroadBase.KNOWN_METHODS``.
        factor : Tensor | float, optional
            Degree of broadening (broadening factor). Default is 1E-12.

        Returns
        -------
        w : Tensor
            The eigenvalues of the matrix, in ascending order.
        v : Tensor
            The eigenvectors of the matrix.

        Raises
        ------
        ValueError
            If `method` is not a known broadening method.

        Warnings
        --------
        The `method` and `factor` parameters must be passed as positional
        arguments and not keyword arguments. No gradient is propagated to
        `factor`.
        """
        # Reject unsupported broadening methods before doing any work.
        if method not in SymEigBroadBase.KNOWN_METHODS:
            raise ValueError(f"Unknown broadening method '{method}' selected.")

        eigvals, eigvecs = torch.linalg.eigh(a)

        # Everything `backward()` reads from the context: the decomposition,
        # the broadening settings and the dtype/device of the input (the
        # latter prevents dtype/device mixing).
        ctx.save_for_backward(eigvals, eigvecs)
        ctx.bf = factor
        ctx.bm = method
        ctx.dtype = a.dtype
        ctx.device = a.device

        return eigvals, eigvecs


class _SymEigBroad_V2(SymEigBroadBase):
    """
    Symmetric eigen-solver with a broadened backward pass, written in the
    custom-function style that separates `forward()` from `setup_context()`
    (PyTorch >= 2.0.0).

    The backward logic is inherited from the base class; see its docstring
    for details on the broadening methods.
    """

    @staticmethod
    def forward(  # type: ignore[override]
        a: Tensor,
        method: str = "cond",
        factor: Tensor | float = 1e-12,
    ) -> tuple[Tensor, Tensor]:
        """
        Diagonalise a real symmetric matrix.

        The decomposition is delegated to `torch.linalg.eigh`. Apart from
        validating `method`, the broadening settings are not used here; they
        are picked up by `setup_context()` for the backward pass.

        Parameters
        ----------
        a : Tensor
            A real symmetric matrix whose eigenvalues and eigenvectors will be
            computed.
        method : str, optional
            The broadening method to be used in the backward pass:
            - 'cond' for conditional broadening.
            - 'lorn' for Lorentzian broadening.
            Default is 'cond'. Must be a member of
            ``SymEigBroadBase.KNOWN_METHODS``.
        factor : Tensor | float, optional
            The degree of broadening (broadening factor). Default is 1E-12.

        Returns
        -------
        tuple of Tensor
            The eigenvalues (`w`, in ascending order) followed by the
            eigenvectors (`v`) of the matrix.

        Raises
        ------
        ValueError
            If `method` is not a known broadening method.

        Warnings
        --------
        Both `method` and `factor` must be passed as positional arguments,
        not keyword arguments. No gradient is propagated to `factor`.
        """
        # Reject unsupported broadening methods before doing any work.
        if method not in SymEigBroadBase.KNOWN_METHODS:
            raise ValueError(f"Unknown broadening method '{method}' selected.")

        eigvals, eigvecs = torch.linalg.eigh(a)
        return eigvals, eigvecs

    @staticmethod
    def setup_context(
        ctx: Any, inputs: tuple[Any, ...], outputs: tuple[Tensor, Tensor]
    ) -> None:
        """
        Store on `ctx` everything the backward pass needs.

        Parameters
        ----------
        ctx : Any
            The context object used to store information for backward
            computation.
        inputs : tuple
            The arguments given to `forward()`, in order: the matrix `a`, the
            broadening method `method`, and the broadening factor `factor`.
        outputs : tuple of Tensor
            The outputs of `forward()`: eigenvalues and eigenvectors.

        Notes
        -----
        This method is part of PyTorch's autograd mechanism and is not
        intended to be called directly by users.
        """
        a, method, factor = inputs[0], inputs[1], inputs[2]
        eigvals, eigvecs = outputs

        # The decomposition, the broadening settings and the dtype/device of
        # the input (the latter prevents dtype/device mixing in `backward()`).
        ctx.save_for_backward(eigvals, eigvecs)
        ctx.bf = factor
        ctx.bm = method
        ctx.dtype = a.dtype
        ctx.device = a.device


def _eig_sort_out(
    w: Tensor,
    v: Tensor,
    ghost: bool = True,
) -> tuple[Tensor, Tensor]:
    """
    Move ghost eigenvalues/vectors to the end of the array.

    This function addresses the issue of ghost eigenvalues/vectors that emerge
    from performing eigen-decomposition on zero-padded packed tensors. Ghosts
    are relocated to the end of the arrays for easy removal.

    Parameters
    ----------
    w : Tensor
        The eigenvalues.
    v : Tensor
        The eigenvectors.
    ghost : bool, optional
        Indicator of the nature of ghost eigenvalues. If True, ghost eigenvalues
        are assumed to be 0. If False, they are assumed to be 1. This should be
        set to True for zero-padded tensors and False for identity-padded
        tensors. Defaults to True. Changing this flag also adjusts ghost
        eigenvalues from 1 to 0 when appropriate.

    Returns
    -------
    Tensor
        The modified eigenvalues with ghosts moved to the end.
    Tensor
        The modified eigenvectors with ghosts moved to the end.

    Notes
    -----
    Ghost eigenvalues/vectors typically emerge when eigen-decomposition is
    performed on matrices that have been zero-padded. These can interfere with
    downstream processes. This function separates them by moving them to the
    end of the tensor, facilitating their removal if desired.

    The term 'ghost' refers to eigenvalues of 0, while 'auxiliary' eigenvalues
    are those set to 1. The choice between treating eigenvalues as ghosts or
    auxiliaries depends on how padding is handled in the input tensor.
    """
    val = 0 if ghost else 1

    # Create a mask that is True when an eigen value is zero/one
    mask = torch.eq(w, val)
    # and its associated eigen vector is a column of a identity matrix:
    # i.e. all values are 1 or 0 and there is only a single 1. This will
    # just all zeros if columns are not one-hot.
    is_one = torch.eq(v, 1)  # <- precompute
    mask &= torch.all(torch.eq(v, 0) | is_one, dim=1)
    mask &= torch.sum(is_one, dim=1) <= 1  # <- Only a single "1" at most.

    # Convert any auxiliary eigenvalues into ghosts
    if not ghost:
        w = w - mask.type(w.dtype)

    # Pull out the indices of the true & ghost entries and cat them together
    # so that the ghost entries are at the end.
    # noinspection PyTypeChecker
    indices = torch.cat(
        (torch.stack(torch.where(~mask)), torch.stack(torch.where(mask))),
        dim=-1,
    )

    # argsort fixes the batch order and stops eigen-values accidentally being
    # mixed between different systems. As PyTorch's argsort is not stable, i.e.
    # it dose not respect any order already present in the data, numpy's argsort
    # must be used for now.
    sorter = np.argsort(indices[0].cpu(), kind="stable")

    # Apply sorter to indices; use a tuple to make 1D & 2D cases compatible
    sorted_indices = tuple(indices[..., sorter])

    # Fix the order of the eigen values and eigen vectors.
    w = w[sorted_indices].reshape(w.shape)
    # Reshaping is needed to allow sorted_indices to be used for 2D & 3D
    v = v.transpose(-1, -2)[sorted_indices].reshape(v.shape).transpose(-1, -2)

    # Return the eigenvalues and eigenvectors
    return w, v


def eighb(
    a: Tensor,
    b: Tensor | None = None,
    scheme: Literal["chol", "lowd"] = "chol",
    broadening_method: Literal["cond", "lorn"] | None = "cond",
    factor: Tensor | float = 1e-12,
    sort_out: bool = True,
    aux: bool = True,
    is_posdef: bool = False,
    **kwargs: Any,
) -> tuple[Tensor, Tensor]:
    r"""
    Solves general & standard eigen-problems, with optional broadening.

    Solves standard and generalised eigenvalue problems of the form Az = λBz
    for a real symmetric matrix `a` and can apply conditional or Lorentzian
    broadening to the eigenvalues during the backwards pass to increase
    gradient stability. Multiple matrices may be passed in batch major form,
    i.e. the first axis iterates over entries.

    Parameters
    ----------
    a : Tensor
        Real symmetric matrix whose eigen-values/vectors will be computed.
    b : Tensor | None, optional
        Complementary positive definite real symmetric matrix for the
        generalised eigenvalue problem. If None, the standard eigenvalue
        problem is solved. [DEFAULT=None]
    scheme : str, optional
        Scheme to convert generalised eigenvalue problems to standard ones:

        - "chol": Cholesky factorisation. [DEFAULT='chol']
        - "lowd": Löwdin orthogonalisation.

        Has no effect on solving standard problems.
    broadening_method : str, optional
        Broadening method to use:

        - "cond": conditional broadening. [DEFAULT='cond']
        - "lorn": Lorentzian broadening.
        - None: no broadening (uses `torch.linalg.eigh` directly).
    factor : Tensor | float, optional
        The degree of broadening (broadening factor). [DEFAULT=1E-12]
    sort_out : bool, optional
        If True, "ghost" entries, i.e. values which emerge as a result of
        zero-padding, are post-processed. With `aux` enabled, the eigenvalues
        at the padded positions are reset to zero; otherwise eigen-vector/value
        tensors are reordered so that ghosts are moved to the end.
        [DEFAULT=True]
    aux : bool, optional
        If True, the diagonal elements of all-zero row/column pairs are
        raised to an upper bound of the spectrum (Gershgorin estimate) before
        the decomposition, so that the padding no longer produces zero-valued
        eigenvalues. This can improve the stability of backwards propagation.
        [DEFAULT=True]
    is_posdef : bool, optional
        Only relevant if `b` is given. If False, the diagonal elements of
        all-zero row/column pairs of `b` are set to one before `b` is
        factorised/decomposed. If True, `b` is used as is. [DEFAULT=False]
    **kwargs : Any
        Additional options. Currently only ``direct_inv`` (bool) is read: if
        True, the inverse of the Cholesky factor is computed directly with
        `torch.inverse` rather than via `torch.linalg.solve`. Only relevant
        to the Cholesky scheme. [DEFAULT=False]

    Returns
    -------
    w : Tensor
        The eigenvalues, in ascending order (before any ghost handling).
    v : Tensor
        The eigenvectors.

    Raises
    ------
    ValueError
        If `b` is given and `scheme` is neither "chol" nor "lowd".

    Notes
    -----
    Results from backward passes through eigen-decomposition operations
    tend to suffer from numerical stability issues when operating on systems
    with degenerate eigenvalues. Fortunately, the stability of such operations
    can be increased through the application of eigen value broadening.
    However, such methods will induce small errors in the returned gradients
    as they effectively mutate the eigen-values in the backwards pass. Thus,
    it is important to be aware that while increasing the extent of broadening
    will help to improve stability it will also increase the error in the
    gradients.

    Two different broadening methods have been implemented. Conditional
    broadening as described by Seeger [MS2019]_, and Lorentzian as detailed by
    Liao [LH2019]_. During the forward pass the `torch.linalg.eigh` function
    is used to calculate both the eigenvalues & the eigenvectors
    (:math:`\lambda` & U respectively). The gradient is then calculated
    following:

    .. math:: \bar{A} = U (\bar{\Lambda} + sym(F \circ (U^T \bar{U}))) U^T

    Where bar indicates a value's gradient, passed in from the previous layer,
    :math:`\bar{\Lambda}` is the diagonal matrix associated with the
    :math:`\bar{\lambda}` values, :math:`\circ` is the so called Hadamard
    product, :math:`sym` is the symmetrisation operator and F is:

    .. math:: F_{i, j} = \frac{I_{i \ne j}}{h(\lambda_i - \lambda_j)}

    Where, for conditional broadening, h is:

    .. math:: h(t) = max(|t|, \epsilon)sgn(t)

    and for, Lorentzian broadening:

    .. math:: h(t) = \frac{t^2 + \epsilon}{t}

    The advantage of conditional broadening is that it is only applied when
    needed, thus the errors induced in the gradients will be restricted to
    systems whose gradients would be nan's otherwise. The Lorentzian method,
    on the other hand, will apply broadening to all systems, irrespective of
    whether or not it is necessary. Note that if the h function is a unity
    operator then this is identical to a standard backwards pass through an
    eigen-solver.

    Mathematical discussions regarding the Cholesky decomposition are made
    with reference to the "Generalized Symmetric Definite Eigenproblems"
    chapter of Lapack. [Lapack]_

    When operating in batch mode the zero valued padding columns and rows
    will result in the generation of "ghost" eigen-values/vectors. These are
    mostly harmless, but make it more difficult to extract the actual
    eigen-values/vectors.

    Warnings
    --------
    If operating upon zero-padded packed tensors then degenerate and zero
    valued eigen values will be encountered. This will **always** cause an
    error during the backwards pass unless broadening is enacted.

    As the eigen-solver sorts its results prior to returning them, it is
    likely that any "ghost" eigen-values/vectors, which result from zero-
    padded packing, will be located in the middle of the returned arrays. This
    makes down-stream processing more challenging. Thus, the `sort_out` option
    is enabled by default. **However**, without `aux`, any entry with a
    zero-valued eigenvalue and an eigenvector which can be interpreted as a
    column of an identity matrix is identified as a ghost. With `aux`, every
    all-zero row/column pair of the input is treated as padding.

    References
    ----------
    .. [MS2019] Seeger, M., Hetzel, A., Dai, Z., & Meissner, E. Auto-
                Differentiating Linear Algebra. ArXiv:1710.08717 [Cs,
                Stat], Aug. 2019. arXiv.org, http://arxiv.org/abs/1710.08717.
    .. [LH2019] Liao, H.-J., Liu, J.-G., Wang, L., & Xiang, T. (2019).
                Differentiable Programming Tensor Networks. Physical
                Review X, 9(3).
    .. [Lapack] www.netlib.org/lapack/lug/node54.html (Accessed 21/04/2023)
    """
    mask = None

    # satisfy type checker
    v: Tensor
    w: Tensor

    if __tversion__ < (2, 0, 0):  # type: ignore[operator] # pragma: no cover
        _SymEigB = _SymEigBroad_V1
    else:
        _SymEigB = _SymEigBroad_V2  # type: ignore[assignment]

    # Initial setup to make function calls easier to deal with
    # If smearing use _SymEigB otherwise use torch.linalg.eigh
    func: Callable = (  # type: ignore[type-arg]
        _SymEigB.apply if broadening_method is not None else torch.linalg.eigh
    )

    # Set up for the arguments; these must be positional for `.apply`
    args = (broadening_method, factor) if broadening_method is not None else ()

    if aux:
        # True wherever a column/row pair of `a` is entirely zero (padding)
        is_zero = torch.eq(a, 0)
        mask = torch.all(is_zero, dim=-1) & torch.all(is_zero, dim=-2)

    if b is None:  # For standard eigenvalue problem
        if aux and mask is not None:
            # Convert from zero-padding to padding with largest eigenvalue
            # estimate
            shift = estimate_minmax(a)[-1].unsqueeze(-1)
            a = a + torch.diag_embed(shift * mask)

        w, v = func(a, *args)  # Call the required eigen-solver

    else:  # Otherwise it will be a general eigenvalue problem
        # Cholesky decomposition can only act on positive definite matrices;
        # which is problematic for zero-padded tensors. Similar issues are
        # encountered in the Löwdin scheme. To ensure positive definiteness
        # the diagonals of padding columns/rows are therefore set to 1.
        if is_posdef is False:
            # Create a mask which is True wherever a column/row pair is
            # 0-valued. Note that this replaces the mask derived from `a`.
            is_zero = torch.eq(b, 0)
            mask = torch.all(is_zero, dim=-1) & torch.all(is_zero, dim=-2)

            # Set the diagonals at these locations to 1
            b = b + torch.diag_embed(mask.type(a.dtype))

        # For Cholesky decomposition scheme
        if scheme == "chol":
            # Perform Cholesky factorization (B = LL^{T}) of B to attain L
            l = torch.linalg.cholesky(b)

            # Compute the inverse of L:
            if kwargs.get("direct_inv", False):
                # Via the direct method if specifically requested
                l_inv = torch.inverse(l)
            else:
                # Otherwise compute via an indirect method (default)
                identity = torch.zeros_like(l)
                identity.diagonal(dim1=-2, dim2=-1)[:] = 1
                l_inv = torch.linalg.solve(l, identity)

            # Transpose of l_inv: improves speed in batch mode
            l_inv_t = torch.transpose(l_inv, -1, -2)

            # To obtain C, perform the reduction operation C = L^{-1}AL^{-T}
            c = l_inv @ a @ l_inv_t

            if aux:
                # Convert from zero-padding to padding with largest
                # eigenvalue estimate
                shift = estimate_minmax(c)[-1].unsqueeze(-1)
                c = c + torch.diag_embed(shift * mask)

            # The eigenvalues of Az = λBz are the same as Cy = λy; hence:
            w, v_ = func(c, *args)

            # Eigenvectors, however, are not, so they must be recovered:
            #   z = L^{-T}y
            v = l_inv_t @ v_

        elif scheme == "lowd":  # For Löwdin Orthogonalisation scheme
            # Perform the BV = WV eigen decomposition.
            w, v = func(b, *args)

            # Embed w to construct "small b"; inverse power is also done here
            # to avoid inf values later on.
            b_small = torch.diag_embed(w**-0.5)

            # Construct symmetric orthogonalisation matrix via:
            #   B^{-1/2} = V b^{-1/2} V^{T}
            b_so = v @ b_small @ v.transpose(-1, -2)

            # A' (a_prime) can then be constructed as:
            #   A' = B^{-1/2} A B^{-1/2}
            a_prime = b_so @ a @ b_so

            if aux:
                # Convert from zero-padding to padding with largest
                # eigenvalue estimate
                shift = estimate_minmax(a_prime)[-1].unsqueeze(-1)
                a_prime = a_prime + torch.diag_embed(shift * mask)

            # Decompose the now orthogonalised A' matrix
            w, v_prime = func(a_prime, *args)

            # the correct eigenvector is then recovered via
            #   V = B^{-1/2} V'
            v = b_so @ v_prime

        else:  # If an unknown scheme was specified
            raise ValueError("Unknown scheme selected.")

    # If sort_out is enabled, nullify the "ghost" eigen-values
    if sort_out:
        if aux and mask is not None:
            # The mask is applied position-wise to the (sorted) eigenvalues.
            # NOTE(review): this presumably relies on the padding being
            # trailing, so that the shifted ghosts end up at the masked
            # positions - confirm against callers.
            w = torch.where(
                ~mask,
                w,
                torch.tensor(0, device=w.device, dtype=w.dtype),
            )
        else:
            w, v = _eig_sort_out(w, v, not aux)

    # Return the eigenvalues and eigenvectors
    return w, v