# This file is part of tad-mctc.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SafeOps: Distance
=================
Functions for calculating the cartesian distance of two vectors.
"""
from __future__ import annotations
import torch
from ..math import einsum
from ..typing import Tensor
from .elemental import sqrt as ssqrt
__all__ = ["cdist"]
import torch
from ..math import einsum # as in your original file
from .elemental import sqrt as ssqrt
def euclidean_dist_quadratic_expansion(x: Tensor, y: Tensor) -> Tensor:
    """
    Euclidean distance matrix obtained from the quadratic expansion
    ``|x_i - y_j|^2 = |x_i|^2 + |y_j|^2 - 2 * <x_i, y_j>``.

    This route is significantly faster than the "direct expansion"
    ("broadcast") approach, but it is restricted to euclidean (p=2)
    distances. It is also less stable numerically: for ``x=y`` the diagonal
    deviates slightly from zero. This should not pose problems, because
    zeros must be removed anyway for batched calculations.

    Background material:
    `Jupyter notebook <https://github.com/eth-cscs/PythonHPC/blob/master/numpy/03-euclidean-distance-matrix-numpy.ipynb>`__,
    `PyTorch forum thread <https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065>`__.

    Parameters
    ----------
    x : Tensor
        First tensor.
    y : Tensor
        Second tensor (expected to have the same shape as ``x``).

    Returns
    -------
    Tensor
        Pair-wise distance matrix.
    """
    # machine epsilon of the input dtype, handed to the safe square root
    finfo = torch.finfo(x.dtype)
    floor = torch.tensor(finfo.eps, device=x.device, dtype=x.dtype)

    # squared norm of every row; einsum is slightly faster than
    # `torch.pow(x, 2).sum(-1)`
    sq_x = einsum("...ij,...ij->...i", x, x)
    sq_y = einsum("...ij,...ij->...i", y, y)

    # entry (i, j) holds |x_i|^2 + |y_j|^2
    norm_sum = sq_x.unsqueeze(-1) + sq_y.unsqueeze(-2)

    # all pair-wise dot products, equivalent to "...ik,...jk->...ij"
    cross = torch.matmul(x, y.mT)

    # important: the safe square root removes negative values, which would
    # otherwise give NaN in the backward pass
    return ssqrt(norm_sum - 2.0 * cross, eps=floor)
def cdist_direct_expansion(x: Tensor, y: Tensor, p: int = 2) -> Tensor:
    """
    Cartesian distance matrix evaluated by explicitly forming all pair-wise
    differences.

    In contrast to `euclidean_dist_quadratic_expansion`, any power ``p`` can
    be used, at the price of a considerably slower evaluation.

    Parameters
    ----------
    x : Tensor
        First tensor.
    y : Tensor
        Second tensor (expected to have the same shape as ``x``).
    p : int, optional
        Power used in the distance evaluation (p-norm). Defaults to 2.

    Returns
    -------
    Tensor
        Pair-wise distance matrix.
    """
    # machine epsilon of the input dtype: lower bound applied before the root
    finfo = torch.finfo(x.dtype)
    lower = torch.tensor(finfo.eps, device=x.device, dtype=x.dtype)

    # insert singleton axes at different positions so that the subtraction
    # broadcasts to a matrix of pair-wise absolute differences
    delta = (x.unsqueeze(-2) - y.unsqueeze(-3)).abs()

    if p != 2:
        summed = delta.pow(p).sum(-1)
    else:
        # einsum is nearly twice as fast for the squared case
        summed = einsum("...ijk,...ijk->...ij", delta, delta)

    # clamp to eps before taking the p-th root
    return summed.clamp(min=lower).pow(1.0 / p)
def cdist(x: Tensor, y: Tensor | None = None, p: int = 2) -> Tensor:
    """
    Wrapper for cartesian distance computation.

    This currently replaces the use of ``torch.cdist``, which does not handle
    zeros well and produces nan's in the backward pass.

    Additionally, ``torch.cdist`` does not return zero for distances between
    same vectors (see `here
    <https://github.com/pytorch/pytorch/issues/57690>`__).

    Parameters
    ----------
    x : Tensor
        First tensor.
    y : Tensor | None, optional
        Second tensor. If no second tensor is given (default), the first tensor
        is used as the second tensor, too.
    p : int, optional
        Power used in the distance evaluation (p-norm). Defaults to 2.

    Returns
    -------
    Tensor
        Pair-wise distance matrix.
    """
    # without a second tensor, compute the distances of `x` with itself
    if y is None:
        y = x

    # the quadratic expansion is faster, but only valid for the euclidean norm
    if p == 2:
        return euclidean_dist_quadratic_expansion(x, y)

    # every other power goes through the slower broadcast-based evaluation
    return cdist_direct_expansion(x, y, p=p)