Source code for tad_mctc.convert.tensor
# This file is part of tad-mctc.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Conversion: Array/Tensor
========================
This module contains function for conversions of PyTorch tensors. This
includes, for example, reshaping.
Conversion into tensors from other data types (integer, float, etc.) is not
provided by this module.
"""
from __future__ import annotations
from functools import partial
import torch
from ..typing import Size, Tensor
__all__ = ["reshape_fortran", "symmetrize", "symmetrizef"]
[docs]
def reshape_fortran(x: Tensor, shape: Size) -> Tensor:
"""
Implements Fortran's `reshape` function (column-major).
Parameters
----------
x : Tensor
Input tensor.
shape : Size
Output size to which `x` is reshaped.
Returns
-------
Tensor
Reshaped tensor of size `shape`.
"""
if len(x.shape) > 0:
x = x.permute(*reversed(range(len(x.shape))))
return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))
[docs]
def symmetrize(x: Tensor, force: bool = False) -> Tensor:
"""
Symmetrize a tensor after checking if it is symmetric within a threshold.
Parameters
----------
x : Tensor
Tensor to check and symmetrize.
force : bool
Whether symmetry should be forced. This allows switching between actual
symmetrizing and only cleaning up numerical noise. Defaults to `False`.
Returns
-------
Tensor
Symmetrized tensor.
Raises
------
RuntimeError
If the tensor is not symmetric within the threshold.
"""
if x.ndim < 2:
raise RuntimeError("Only matrices and batches thereof allowed.")
if force is True:
return (x + x.mT) / 2.0
try:
atol = torch.finfo(x.dtype).eps * 10
except TypeError: # pragma: no cover
atol = 1e-5
if not torch.allclose(x, x.mT, atol=atol):
raise RuntimeError(
f"Matrix appears to be not symmetric (atol={atol:.3e}, "
f"dtype={x.dtype})."
)
return (x + x.mT) / 2.0
symmetrizef = partial(symmetrize, force=True)
"""Force symmetrization of a tensor."""