SafeOps: Elementary Functions
Safe versions of elementary functions like sqrt or abs.
- tad_mctc.storch.elemental.divide(x, y, *, eps=None, **kwargs)
Safe divide operation. Only adds a small value to the denominator where it is zero.
- Parameters:
x (Tensor) – Input tensor (numerator).
y (Tensor) – Input tensor (denominator).
eps (Tensor | float | int | None, optional) – Value added to the denominator. Defaults to None, which resolves to torch.finfo(x.dtype).eps.
- Returns:
Quotient x / y of the input tensors.
- Return type:
Tensor
- Raises:
TypeError – The value added to the denominator has the wrong type.
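The divide entry states that a small value is added only where the denominator is zero. A minimal sketch of that behavior in plain PyTorch follows; safe_divide_sketch is a hypothetical name, not part of tad_mctc, and its type check is an assumption modeled on the documented TypeError rather than the library's actual validation.

```python
import torch
from typing import Optional, Union

# Types the documentation lists for eps.
Eps = Union[torch.Tensor, float, int]


def safe_divide_sketch(
    x: torch.Tensor, y: torch.Tensor, eps: Optional[Eps] = None
) -> torch.Tensor:
    """Hypothetical sketch: x / y with eps added only where y == 0."""
    if eps is None:
        # Documented default: machine epsilon of x's dtype.
        eps = torch.finfo(x.dtype).eps
    elif not isinstance(eps, (torch.Tensor, float, int)):
        # Assumed analogue of the documented TypeError.
        raise TypeError(f"Unsupported type for eps: {type(eps).__name__}")
    # Non-zero denominators are left untouched; zeros become eps.
    y_safe = torch.where(y == 0, y + eps, y)
    return x / y_safe


x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([2.0, 0.0, 4.0])
out = safe_divide_sketch(x, y)
print(out)  # middle entry is 2 / eps instead of inf
```

Because torch.where only substitutes the zero entries, every other quotient is exactly what plain x / y would return.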
- tad_mctc.storch.elemental.pow(x, exponent, *, eps=None)
Raises each element of x to the power exponent and returns a tensor with the result.
This is a safer version of torch.pow (out = x ** exponent) that avoids:
- NaN/imaginary output when x < 0 and exponent has a fractional part. In this case, the function returns the signed (negative) magnitude of the complex number.
- NaN/infinite gradient at x = 0 when exponent has a fractional part. In this case, epsilon is added at the positions where x = 0, so the gradient is back-propagated as if x = epsilon.
However, this function does not deal with float overflow, such as 1e10000.
- Parameters:
x (torch.Tensor or float) – The input base value.
exponent (torch.Tensor or float) – The exponent value.
At least one of x and exponent must be a torch.Tensor.
eps (float | None, optional) – A small floating-point value that avoids an infinite gradient at x = 0. Default: 1e-6
- Returns:
out – The output tensor.
- Return type:
torch.Tensor
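The two safeguards described for pow can be sketched in plain PyTorch as below. safe_pow_sketch is a hypothetical name, it only handles a Python float exponent (the documented function also accepts a tensor), and it reuses the documented 1e-6 default; tad_mctc's actual implementation may differ.

```python
import torch


def safe_pow_sketch(x: torch.Tensor, exponent: float, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical sketch of the pow safeguards (float exponent only)."""
    if float(exponent).is_integer():
        # Integer exponents are well defined for negative bases and at zero.
        return x**exponent
    # Shift exact zeros by eps. The x + eps branch keeps the gradient flowing
    # to x, evaluated as if x = eps (finite) rather than x = 0 (infinite).
    x_shifted = torch.where(x == 0, x + eps, x)
    # For negative bases, return the signed (negative) magnitude |x|**p of the
    # complex result instead of NaN.
    return torch.sign(x_shifted) * x_shifted.abs() ** exponent


x = torch.tensor([-8.0, 0.0, 4.0], requires_grad=True)
out = safe_pow_sketch(x, 0.5)
out.sum().backward()
print(out)     # approx. [-2.8284, 0.0010, 2.0000]
print(x.grad)  # approx. [0.1768, 500.0, 0.2500]
```

The gradient at the former zero is 0.5 * eps ** -0.5, which is large but finite, whereas torch.pow would yield inf there.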
- tad_mctc.storch.elemental.reciprocal(x, *, eps=None, **kwargs)
Safe reciprocal operation.
- Parameters:
x (Tensor) – Input tensor (denominator).
eps (Tensor | float | int | None, optional) – Value added to the denominator. Defaults to None, which resolves to torch.finfo(x.dtype).eps.
- Returns:
Reciprocal of the input tensor.
- Return type:
Tensor
- Raises:
TypeError – The value added to the denominator has the wrong type.
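The sketch below contrasts torch.reciprocal with a guarded version, both for the value and for the gradient. It assumes, without confirmation from this page, that reciprocal shifts only exact zeros, as divide does; safe_reciprocal_sketch is a hypothetical name, not the library function.

```python
import torch
from typing import Optional


def safe_reciprocal_sketch(x: torch.Tensor, eps: Optional[float] = None) -> torch.Tensor:
    """Hypothetical sketch: 1 / x with exact zeros shifted by eps."""
    if eps is None:
        # Documented default: machine epsilon of x's dtype.
        eps = torch.finfo(x.dtype).eps
    # Assumption: like divide, only zero entries receive eps.
    return torch.reciprocal(torch.where(x == 0, x + eps, x))


x = torch.tensor([0.0, 0.5, -4.0], dtype=torch.float64, requires_grad=True)
plain = torch.reciprocal(x.detach())  # first entry is inf
safe = safe_reciprocal_sketch(x)      # first entry is 1 / eps
safe.sum().backward()                 # gradient -1 / x**2 stays finite
```

In float64, 1 / eps is 2**52 and the corresponding gradient is about -2e31, which is huge but representable.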
- tad_mctc.storch.elemental.sqrt(x, *, eps=None)
Safe square root operation.
- Parameters:
x (Tensor) – Input tensor.
eps (Tensor | float | int | None, optional) – Value for clamping. Defaults to None, which resolves to torch.finfo(x.dtype).eps.
- Returns:
Square root of the input tensor.
- Return type:
Tensor
- Raises:
TypeError – The clamping value has the wrong type.
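The phrase "value for clamping" suggests that the input is clamped from below by eps before the root is taken. Under that assumption, the following sketch (safe_sqrt_sketch is a hypothetical name) shows how clamping avoids both the NaN that torch.sqrt returns for negative inputs and its infinite gradient at zero.

```python
import torch
from typing import Optional


def safe_sqrt_sketch(x: torch.Tensor, eps: Optional[float] = None) -> torch.Tensor:
    """Hypothetical sketch: square root of x clamped from below by eps."""
    if eps is None:
        # Documented default: machine epsilon of x's dtype.
        eps = torch.finfo(x.dtype).eps
    return torch.sqrt(torch.clamp(x, min=eps))


x = torch.tensor([-1.0, 0.0, 4.0], requires_grad=True)
out = safe_sqrt_sketch(x)
out.sum().backward()

# Plain torch.sqrt for comparison: NaN gradient at -1, infinite gradient at 0.
z = torch.tensor([-1.0, 0.0, 4.0], requires_grad=True)
torch.sqrt(z).sum().backward()
```

Clamped entries receive a zero gradient, because torch.clamp does not propagate gradients to inputs below the bound.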