Source code for tad_mctc.storch.elemental

# This file is part of tad-mctc.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SafeOps: Elementary Functions
=============================

Safe versions of elementary functions like `sqrt`, `pow`, or `divide`.
"""

from __future__ import annotations

import torch

from ..typing import Any, Tensor
from .utils import get_eps

# Public API of this module. NOTE: ``pow`` deliberately shadows the builtin
# ``pow`` inside this module; call it qualified (e.g. ``storch.pow``).
__all__ = ["divide", "pow", "reciprocal", "sqrt"]


def divide(
    x: Tensor,
    y: Tensor,
    *,
    eps: Tensor | float | int | None = None,
    **kwargs: Any,
) -> Tensor:
    """
    Safe divide operation.

    Only the zero entries of the denominator are modified: they are replaced
    by ``eps`` before dividing, so the quotient stays finite there. Non-zero
    entries of the denominator are left untouched.

    Parameters
    ----------
    x : Tensor
        Input tensor (numerator).
    y : Tensor
        Input tensor (denominator).
    eps : Tensor | float | int | None, optional
        Value that replaces zeros in the denominator. Defaults to `None`,
        which resolves to `torch.finfo(x.dtype).eps`.
    **kwargs : Any
        Additional keyword arguments forwarded to ``torch.divide``
        (e.g. ``rounding_mode``).

    Returns
    -------
    Tensor
        Element-wise quotient ``x / y``, with zeros in ``y`` replaced by
        ``eps``.

    Raises
    ------
    TypeError
        Value for replacing zeros in the denominator has wrong type.
    """
    if eps is None:
        eps = get_eps(x)
    elif isinstance(eps, (float, int)):
        eps = torch.tensor(eps, device=x.device, dtype=x.dtype)
    elif isinstance(eps, Tensor):
        eps = eps.to(device=x.device, dtype=x.dtype)
    else:
        raise TypeError(
            "Value for clamping must be None (default), Tensor, float, or int, "
            f"but {type(eps)} was given."
        )

    # Substitute only exact zeros so that regular denominators are not biased.
    y_safe = torch.where(y != 0, y, eps)
    return torch.divide(x, y_safe, **kwargs)
def reciprocal(
    x: Tensor, *, eps: Tensor | float | int | None = None, **kwargs: Any
) -> Tensor:
    """
    Safe reciprocal ``1 / (x + eps)``.

    ``eps`` is added to *every* element of ``x`` (not only to zeros), which
    keeps the result finite at ``x = 0``. Elements equal to ``-eps`` still
    produce a zero denominator.

    Parameters
    ----------
    x : Tensor
        Input tensor (denominator).
    eps : Tensor | float | int | None, optional
        Offset added to the denominator. Defaults to ``None``, which resolves
        to ``torch.finfo(x.dtype).eps``.
    **kwargs : Any
        Additional keyword arguments forwarded to ``torch.divide``.

    Returns
    -------
    Tensor
        Reciprocal of the shifted input tensor.

    Raises
    ------
    TypeError
        Offset for the denominator has wrong type.
    """
    if isinstance(eps, Tensor):
        shift = eps.to(device=x.device, dtype=x.dtype)
    elif isinstance(eps, (float, int)):
        shift = torch.tensor(eps, device=x.device, dtype=x.dtype)
    elif eps is None:
        shift = get_eps(x)
    else:
        raise TypeError(
            "Value for clamping must be None (default), Tensor, float, or int, "
            f"but {type(eps)} was given."
        )

    # 0-d "1" living on the same device/dtype as the input.
    numerator = torch.ones((), device=x.device, dtype=x.dtype)
    denominator = x + shift
    return torch.divide(numerator, denominator, **kwargs)
def pow(
    x: Tensor,
    exponent: Tensor | float | int,
    *,
    eps: Tensor | float | int | None = None,
) -> Tensor:
    """
    Safe element-wise power ``x ** exponent``.

    This is a safer version of ``torch.pow`` that avoids the following
    problems (float overflow, e.g. ``1e200 ** 2``, is *not* handled):

    1. NaN output for ``x < 0`` with a fractional exponent. Such bases are
       replaced by ``eps``, i.e. the result there is ``eps ** exponent``.
    2. Infinite output for ``x = 0`` with a negative exponent. Such bases
       are replaced by ``eps``.
    3. Infinite gradient for ``x = 0`` with a positive fractional exponent
       (e.g. ``0.5``). The forward value stays exactly ``0``; internally the
       power is evaluated at ``eps``, and the gradient with respect to ``x``
       is zero there.

    Wherever a base is replaced, no gradient flows back to ``x``.
    Integer-valued exponents are applied to negative bases exactly, because
    ``torch.pow`` handles them without producing NaN.

    Parameters
    ----------
    x : Tensor
        Base tensor.
    exponent : Tensor | float | int
        Exponent. Floats with an integral value are treated like integers.
        Tensor exponents are handled element-wise (broadcast against ``x``).
    eps : Tensor | float | int | None, optional
        Replacement value for problematic bases; must be strictly positive.
        Defaults to ``None``, which resolves to ``torch.finfo(x.dtype).eps``.

    Returns
    -------
    Tensor
        Element-wise power of the (safeguarded) input.

    Raises
    ------
    TypeError
        Value for ``eps`` has wrong type.
    ValueError
        Value for ``eps`` is not strictly positive, or ``exponent`` has an
        unsupported type.
    """
    if eps is None:
        eps = get_eps(x)
    elif isinstance(eps, (float, int)):
        eps = torch.tensor(eps, device=x.device, dtype=x.dtype)
    elif isinstance(eps, Tensor):
        eps = eps.to(device=x.device, dtype=x.dtype)
    else:
        raise TypeError(
            "Value for clamping must be None (default), Tensor, float, or int, "
            f"but {type(eps)} was given."
        )

    # eps is substituted as a base for fractional exponents, so it must be
    # strictly positive (a negative eps would reintroduce NaN).
    if (eps <= 0).any():
        raise ValueError(
            f"Value for clamping must be larger than 0.0, but {eps} was given."
        )

    def _int(base: Tensor, exp: int) -> Tensor:
        # Positive integer exponents are well-defined for every base.
        if exp > 0:
            return torch.pow(base, exp)
        # Zero/negative integer exponents blow up at base == 0, so those
        # entries are evaluated at eps instead.
        return torch.pow(torch.where(base == 0, eps, base), exp)

    def _float(base: Tensor, exp: float) -> Tensor:
        # Fractional powers are only real-valued for positive bases.
        safe = torch.where(base <= 0, eps, base)
        out = torch.pow(safe, exp)
        if exp > 0:
            # 0 ** exp == 0 for positive exponents: restore the exact value.
            # The power was evaluated at eps, so no infinite gradient (and no
            # inf * 0 = NaN) leaks back to ``base``.
            out = torch.where(base == 0, torch.zeros_like(out), out)
        return out

    if isinstance(exponent, int):
        return _int(x, exponent)

    if isinstance(exponent, float):
        if exponent.is_integer():
            return _int(x, int(exponent))
        return _float(x, exponent)

    if isinstance(exponent, Tensor):
        exponent = exponent.to(device=x.device)

        # Element-wise: which exponents have a fractional part?
        if exponent.is_floating_point():
            frac = exponent != torch.round(exponent)
        else:
            frac = torch.zeros_like(exponent, dtype=torch.bool)

        # Bases to evaluate at eps: fractional exponents need x > 0, negative
        # integer exponents need x != 0. Integer exponents are exact for
        # negative bases and must not be clobbered.
        replace = torch.where(frac, x <= 0, (exponent < 0) & (x == 0))
        out = torch.pow(torch.where(replace, eps, x), exponent)

        # Restore the exact value 0 for 0 ** (positive fraction), as in
        # ``_float``.
        zero = frac & (exponent > 0) & (x == 0)
        return torch.where(zero, torch.zeros_like(out), out)

    raise ValueError(
        "Value for exponent must be integer, float, or Tensor, but "
        f"{type(exponent)} was given."
    )
def sqrt(x: Tensor, *, eps: Tensor | float | int | None = None) -> Tensor:
    """
    Safe square root operation.

    The input is clamped from below to ``eps`` before taking the square
    root, so entries smaller than ``eps`` (including zero and negative
    values) evaluate to ``sqrt(eps)`` and receive zero gradient. For
    ``eps > 0`` both the result and its gradient stay finite.

    Parameters
    ----------
    x : Tensor
        Input tensor.
    eps : Tensor | float | int | None, optional
        Value for clamping. A tensor may hold one value per element (it is
        broadcast against ``x``). Defaults to ``None``, which resolves to
        ``torch.finfo(x.dtype).eps``.

    Returns
    -------
    Tensor
        Square root of the clamped input tensor.

    Raises
    ------
    TypeError
        Value for clamping has wrong type.
    ValueError
        Value for clamping is negative.
    """
    if eps is None:
        eps = get_eps(x)
    elif isinstance(eps, (float, int)):
        eps = torch.tensor(eps, device=x.device, dtype=x.dtype)
    elif isinstance(eps, Tensor):
        eps = eps.to(device=x.device, dtype=x.dtype)
    else:
        raise TypeError(
            "Value for clamping must be None (default), Tensor, float, or int, "
            f"but {type(eps)} was given."
        )

    # ``.any()`` keeps the check valid for multi-element eps tensors; a bare
    # ``if eps < 0.0`` is ambiguous for those and raises a RuntimeError.
    if (eps < 0.0).any():
        raise ValueError(
            f"Value for clamping must be larger than 0.0, but {eps} was given."
        )

    return torch.sqrt(torch.clamp(x, min=eps))