Autograd Utility: Checks

Utility functions for checking properties of tensors in the context of automatic differentiation, such as whether a tensor is a grad-tracking tensor, a batched tensor, or both (i.e., a “functorch” tensor).

tad_mctc.autograd.checks.is_batched(x)

Check if the input tensor is a batched tensor.

Only the outermost wrapper layer is checked, so a grad-tracking wrapper can hide an underlying batched tensor. To check the underlying tensor, unwrap it first.
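To illustrate the layering, the sketch below nests `torch.func.grad` inside `torch.func.vmap` and inspects the function's input with PyTorch's private `torch._C._functorch` helpers (`is_batchedtensor`, `is_gradtrackingtensor`, `get_unwrapped`). Using these helpers is an assumption about how such a check can be made, not necessarily what tad_mctc does internally; they require PyTorch 2.0 or newer and are not a stable public API. The `layers` dict and `loss` function are hypothetical.

```python
import torch
from torch.func import grad, vmap

# Private PyTorch (>= 2.0) helpers; not a stable public API.
from torch._C._functorch import get_unwrapped, is_batchedtensor, is_gradtrackingtensor

layers = {}


def loss(x: torch.Tensor) -> torch.Tensor:
    # Under vmap(grad(...)), x is a grad-tracking wrapper around a batched tensor.
    layers["outer_is_gradtracking"] = is_gradtrackingtensor(x)
    layers["outer_is_batched"] = is_batchedtensor(x)
    # Peel off one wrapper layer to reach the batched tensor underneath.
    layers["inner_is_batched"] = is_batchedtensor(get_unwrapped(x))
    return (x**2).sum()


xs = torch.arange(12.0).reshape(4, 3)
grads = vmap(grad(loss))(xs)  # per-sample gradients, shape (4, 3)
print(layers)
```

Here the outermost layer reports as grad-tracking but not batched; the batched tensor only becomes visible after unwrapping one layer.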

Note

Always returns False for PyTorch versions earlier than 2.0.0.
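The version gate in the note can be sketched in plain Python. The helpers `version_tuple` and `functorch_checks_available` are hypothetical; a real check would pass `torch.__version__`, and tad_mctc's own comparison may be implemented differently.

```python
from __future__ import annotations

import re


def version_tuple(version: str) -> tuple[int, int, int]:
    """Parse the leading 'major.minor[.patch]' part of a version string.

    Suffixes such as '+cu118' or 'a0' are ignored, so '2.0.0a0' is
    treated the same as '2.0.0'.
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def functorch_checks_available(version: str) -> bool:
    """Hypothetical gate: True only for versions 2.0.0 and newer."""
    return version_tuple(version) >= (2, 0, 0)


print(functorch_checks_available("1.13.1"))
print(functorch_checks_available("2.1.0+cu118"))
```

Below the threshold, a check gated this way would return False without touching the functorch internals at all.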

Parameters:

x (Tensor) – The tensor to check.

Returns:

True if the tensor is a batched tensor, False otherwise.

Return type:

bool

tad_mctc.autograd.checks.is_functorch_tensor(x)

Check if the input tensor is a functorch tensor.
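A minimal sketch of a combined check, assuming “functorch tensor” means that the outermost layer is either batched (from `vmap`) or grad-tracking (from `grad`). This reading of the module description is an assumption, and the helper `has_functorch_wrapper` is hypothetical; it relies on PyTorch's private `torch._C._functorch` helpers (PyTorch 2.0 or newer).

```python
import torch
from torch.func import grad, vmap

# Private PyTorch (>= 2.0) helpers; not a stable public API.
from torch._C._functorch import is_batchedtensor, is_gradtrackingtensor


def has_functorch_wrapper(x: torch.Tensor) -> bool:
    """Hypothetical helper: True if the outermost layer is batched or grad-tracking."""
    return is_batchedtensor(x) or is_gradtrackingtensor(x)


results = {}


def under_vmap(x: torch.Tensor) -> torch.Tensor:
    results["vmap"] = has_functorch_wrapper(x)  # x is a BatchedTensor here
    return x.sin()


def under_grad(x: torch.Tensor) -> torch.Tensor:
    results["grad"] = has_functorch_wrapper(x)  # x is a GradTrackingTensor here
    return x.sin().sum()


x = torch.randn(3)
vmap(under_vmap)(x)
grad(under_grad)(x)
results["plain"] = has_functorch_wrapper(x)  # ordinary tensor, no wrapper
print(results)
```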

Note

Always returns False for PyTorch versions earlier than 2.0.0.

Parameters:

x (Tensor) – The tensor to check.

Returns:

True if the tensor is a functorch tensor, False otherwise.

Return type:

bool

tad_mctc.autograd.checks.is_gradtracking(x)

Check if the input tensor is a grad-tracking tensor.
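A grad-tracking tensor in the functorch sense is a wrapper created by `torch.func.grad`; it is not the same as an ordinary tensor with `requires_grad=True`. The sketch below shows the difference, assuming `is_gradtracking` behaves like PyTorch's private `torch._C._functorch.is_gradtrackingtensor` (an assumption; PyTorch 2.0 or newer). The `flags` dict and function `f` are hypothetical.

```python
import torch
from torch.func import grad

# Private PyTorch (>= 2.0) helper; not a stable public API.
from torch._C._functorch import is_gradtrackingtensor

flags = {}


def f(x: torch.Tensor) -> torch.Tensor:
    # Inside torch.func.grad, the input is wrapped in a GradTrackingTensor.
    flags["inside_grad"] = is_gradtrackingtensor(x)
    return (x**3).sum()


# A regular autograd leaf: requires_grad=True, but no functorch wrapper.
x = torch.tensor([1.0, 2.0], requires_grad=True)
flags["plain_requires_grad"] = is_gradtrackingtensor(x)

g = grad(f)(x.detach())  # d/dx sum(x**3) = 3 * x**2
print(flags, g)
```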

Note

Always returns False for PyTorch versions earlier than 2.0.0.

Parameters:

x (Tensor) – The tensor to check.

Returns:

True if the tensor is a grad-tracking tensor, False otherwise.

Return type:

bool