SafeOps: Elementary Functions

Safe versions of elementary functions like sqrt or abs.

tad_mctc.storch.elemental.divide(x, y, *, eps=None, **kwargs)

Safe divide operation. A small value is added to the denominator only where it is zero, so that division by zero is avoided.

Parameters:
  • x (Tensor) – Input tensor (numerator).

  • y (Tensor) – Input tensor (denominator).

  • eps (Tensor | float | int | None, optional) – Value added to the denominator. Defaults to None, which resolves to torch.finfo(x.dtype).eps.

Returns:

Quotient of x and y.

Return type:

Tensor

Raises:

TypeError – The value added to the denominator (eps) has the wrong type.
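
The following is a minimal sketch of the behavior described above, written with plain PyTorch. The function name safe_divide_sketch is hypothetical and this is not the library's source; it assumes floating-point inputs and reproduces only what the documentation states: eps defaults to torch.finfo(x.dtype).eps, a wrongly typed eps raises TypeError, and eps is added only where the denominator is zero.

```python
import torch


def safe_divide_sketch(x, y, eps=None):
    """Hypothetical re-implementation of the documented behavior of divide()."""
    if eps is None:
        # Documented default: machine epsilon of the numerator's dtype.
        eps = torch.finfo(x.dtype).eps
    elif not isinstance(eps, (torch.Tensor, float, int)):
        # Mirrors the documented TypeError for a wrongly typed eps.
        raise TypeError("Value for addition to denominator has wrong type.")
    # Only entries where the denominator is exactly zero are shifted by eps;
    # all other entries are divided unchanged.
    denominator = torch.where(y == 0, y + eps, y)
    return x / denominator


x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([2.0, 0.0, 4.0])
print(safe_divide_sketch(x, y))  # middle entry is 2 / eps instead of inf
```

Because nonzero denominators are left untouched, the result matches ordinary division everywhere except at the zero entries.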

tad_mctc.storch.elemental.pow(x, exponent, *, eps=None)

Raises each element of x to the power exponent and returns a tensor with the result.

This is a safer version of torch.pow (out = x ** exponent), which avoids:

  1. NaN/imaginary output when x < 0 and exponent has a fractional part

    In this case, the function returns the negated magnitude of the complex result, i.e. -|x| ** exponent.

  2. NaN/infinite gradient at x = 0 when exponent has a fractional part

    In this case, epsilon is added to the entries where x = 0, so the gradient is back-propagated as if x = epsilon.

However, this function does not guard against floating-point overflow (values such as 1e10000).
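
A minimal sketch of the two safeguards, assuming x is a floating-point tensor and exponent a Python float. The name safe_pow_sketch and the torch.where-based structure are illustrative assumptions, not the library's implementation; the default eps of 1e-6 follows the documented epsilon default.

```python
import torch


def safe_pow_sketch(x: torch.Tensor, exponent: float, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical sketch of the safeguards documented for pow()."""
    p = torch.as_tensor(exponent, dtype=x.dtype)
    fractional = p != torch.round(p)

    # Safeguard 2: shift exact zeros by eps (fractional exponents only), so the
    # gradient is evaluated as if x = eps instead of diverging.
    zero = (x == 0) & fractional
    x = torch.where(zero, x + eps, x)

    # Safeguard 1: a negative base with a fractional exponent yields -|x| ** p,
    # the negated magnitude of the complex result.
    neg = (x < 0) & fractional
    magnitude = x.abs() ** p

    # Replace the problematic entries by 1 before the plain power, so that no
    # NaN is produced there and none leaks into the backward pass.
    plain = torch.where(neg, torch.ones_like(x), x) ** p
    return torch.where(neg, -magnitude, plain)


x = torch.tensor([-8.0, 0.0, 4.0], requires_grad=True)
out = safe_pow_sketch(x, 0.5)
out.sum().backward()
print(out.detach())  # -sqrt(8), sqrt(eps), 2
print(x.grad)        # all entries finite
print(torch.pow(torch.tensor([-8.0]), 0.5))  # plain torch.pow gives nan
```

Sanitizing the input of the plain power is deliberate: torch.where passes a zero gradient into the branch it does not select, and zero times a NaN or infinite local derivative is still NaN. Integer exponents take the ordinary path, so (-2) ** 2 still gives 4.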

Parameters:
  • x (torch.Tensor or float) – The input base value.

  • exponent (torch.Tensor or float) – The exponent value.

    (At least one of x and exponent must be a torch.Tensor)

  • eps (float, optional) – A small floating-point value added where x = 0 to avoid an infinite gradient. Default: 1e-6

Returns:

out – The output tensor.

Return type:

torch.Tensor

tad_mctc.storch.elemental.reciprocal(x, *, eps=None, **kwargs)

Safe reciprocal operation (1 / x), with a small value added to the denominator to avoid division by zero.

Parameters:
  • x (Tensor) – Input tensor (denominator).

  • eps (Tensor | float | int | None, optional) – Value added to the denominator. Defaults to None, which resolves to torch.finfo(x.dtype).eps.

Returns:

Reciprocal of the input tensor.

Return type:

Tensor

Raises:

TypeError – The value added to the denominator (eps) has the wrong type.
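
The sketch below shows how the documented default eps depends on the input dtype. It assumes that reciprocal follows the same rule as divide (eps added only where x is zero), which the documentation does not state explicitly; safe_reciprocal_sketch is a hypothetical name, not the library's function.

```python
import torch


def safe_reciprocal_sketch(x, eps=None):
    """Hypothetical sketch of reciprocal(); not the library's implementation."""
    if eps is None:
        # Documented default: machine epsilon of the input dtype.
        eps = torch.finfo(x.dtype).eps
    elif not isinstance(eps, (torch.Tensor, float, int)):
        raise TypeError("Value for addition to denominator has wrong type.")
    # Assumption: eps is added only where x == 0, mirroring divide().
    return 1.0 / torch.where(x == 0, x + eps, x)


for dtype in (torch.float32, torch.float64):
    x = torch.tensor([0.0, 2.0], dtype=dtype, requires_grad=True)
    out = safe_reciprocal_sketch(x)
    out.sum().backward()
    # The zero entry becomes 1 / eps, so its value depends on the dtype;
    # the gradient is large but finite.
    print(dtype, out.detach(), x.grad)
```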

tad_mctc.storch.elemental.sqrt(x, *, eps=None)

Safe square root operation. The input is clamped using eps before the square root is taken.

Parameters:
  • x (Tensor) – Input tensor.

  • eps (Tensor | float | int | None, optional) – Value for clamping. Defaults to None, which resolves to torch.finfo(x.dtype).eps.

Returns:

Square root of the input tensor.

Return type:

Tensor

Raises:

TypeError – The clamping value (eps) has the wrong type.
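
A minimal sketch, assuming that "clamping" means the input is bounded from below by eps through torch.clamp(x, min=eps); the name safe_sqrt_sketch is hypothetical and this is not the library's source. It contrasts the gradient at zero with plain torch.sqrt.

```python
import torch


def safe_sqrt_sketch(x, eps=None):
    """Hypothetical sketch of sqrt(); assumes eps is a lower clamping bound."""
    if eps is None:
        # Documented default: machine epsilon of the input dtype.
        eps = torch.finfo(x.dtype).eps
    elif not isinstance(eps, (torch.Tensor, float, int)):
        raise TypeError("Value for clamping has wrong type.")
    # Entries below eps (zero and negative values) are raised to eps, so the
    # forward value is sqrt(eps) instead of 0 or NaN.
    return torch.sqrt(torch.clamp(x, min=eps))


x = torch.tensor([0.0, 4.0], requires_grad=True)
safe_sqrt_sketch(x).sum().backward()
print(x.grad)  # the clamped zero entry receives a zero gradient

x_plain = torch.tensor([0.0, 4.0], requires_grad=True)
torch.sqrt(x_plain).sum().backward()
print(x_plain.grad)  # the zero entry has an infinite gradient
```

A side effect of this assumed clamp is that negative inputs return sqrt(eps) instead of NaN, and their gradient is zero.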