Source code for tad_mctc.autograd.checks
# This file is part of tad-mctc.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Autograd Utility: Checks
========================
Utility functions for checking properties of tensors in the context of
automatic differentiation, such as whether a tensor is a grad tracking tensor,
batched tensor, or a both (i.e., a "functorch" tensor).
"""
from __future__ import annotations
import torch
from .._version import __tversion__
from ..typing import Tensor
__all__ = ["is_gradtracking", "is_batched", "is_functorch_tensor"]
[docs]
def is_gradtracking(x: Tensor) -> bool:
"""
Check if the input tensor is a grad tracking tensor.
Note
----
Defaults to ``False`` for versions of PyTorch before 2.0.0.
Parameters
----------
x : Tensor
The tensor to check.
Returns
-------
bool
``True`` if the tensor is a grad tracking tensor, ``False`` otherwise.
"""
if __tversion__ >= (2, 0, 0):
return torch._C._functorch.is_gradtrackingtensor(x)
return False
[docs]
def is_batched(x: Tensor) -> bool:
"""
Check if the input tensor is a batched tensor.
Only checks the first wrapper layer, i.e., grad-tracking tensors can
obscure the batched nature of a tensor. Unwrap the tensor first to check
the underlying tensor.
Note
----
Defaults to ``False`` for versions of PyTorch before 2.0.0.
Parameters
----------
x : Tensor
The tensor to check.
Returns
-------
bool
``True`` if the tensor is a batched tensor, ``False`` otherwise.
"""
if __tversion__ >= (2, 0, 0):
return torch._C._functorch.is_batchedtensor(x)
return False
[docs]
def is_functorch_tensor(x: Tensor) -> bool:
"""
Check if the input tensor is a functorch tensor.
Note
----
Defaults to ``False`` for versions of PyTorch before 2.0.0.
Parameters
----------
x : Tensor
The tensor to check.
Returns
-------
bool
``True`` if the tensor is a functorch tensor, ``False`` otherwise.
"""
if __tversion__ >= (2, 0, 0):
return is_gradtracking(x) or is_batched(x)
return False