# Source code for tad_mctc.autograd.hessian
# This file is part of tad-mctc.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Autograd Utility: Hessian
=========================
Utilities for calculating Hessians via autograd.
Note
----
Batched Hessians are not supported yet (via `vmap`).
"""
from __future__ import annotations
import torch
from ..typing import Any, Callable, Tensor
from .compat import jacrev_compat as jacrev
__all__ = ["hessian", "hess_fn_rev"]
# [docs]
def hessian(
    f: Callable[..., Tensor],
    inputs: tuple[Any, ...],
    argnums: int = 0,
    is_batched: bool = False,
    **kwargs: Any,
) -> Tensor:
    """
    Wrapper for Hessian. The Hessian is the Jacobian of the gradient.

    PyTorch, however, suggests calculating the Jacobian of the Jacobian, which
    does not yield the correct shape in this case.

    Parameters
    ----------
    f : Callable[..., Tensor]
        The function whose result is differentiated. Its output is summed
        before the gradient is taken.
    inputs : tuple[Any, ...]
        The input parameters of `f`.
    argnums : int, optional
        The variable w.r.t. which will be differentiated. Defaults to 0.
    is_batched : bool, optional
        Request a batched Hessian. This is not implemented yet and always
        raises. Defaults to ``False``.
    **kwargs : Any
        Additional keyword arguments forwarded to `jacrev`.

    Returns
    -------
    Tensor
        The Hessian.

    Raises
    ------
    ValueError
        The parameter selected for differentiation (via `argnums`) is not a
        tensor.
    NotImplementedError
        A batched Hessian was requested (``is_batched=True``).
    """
    if not isinstance(inputs[argnums], Tensor):
        raise ValueError(
            f"The {argnums}'th input parameter must be a tensor but is of "
            f"type '{type(inputs[argnums])}'."
        )

    # Fail fast: reject the unsupported mode before any transform is built.
    # TODO: batching could presumably be added by wrapping the Jacobian
    # function in `torch.func.vmap` over the `argnums` input -- unverified.
    if is_batched:
        raise NotImplementedError("Batched Hessian not available.")

    def _grad(*inps: Any) -> Tensor:
        # Reduce the output of `f` to a scalar so that a plain gradient
        # (instead of a full Jacobian) can be taken.
        e = f(*inps).sum()

        # catch missing gradients: without a graph there is nothing to
        # differentiate, so the gradient is reported as zero
        if e.grad_fn is None:
            return torch.zeros_like(inps[argnums])

        # `create_graph=True` keeps the gradient itself differentiable, which
        # the outer Jacobian requires.
        (g,) = torch.autograd.grad(
            e,
            inps[argnums],
            create_graph=True,
        )
        return g

    # Jacobian of the gradient w.r.t. the same argument.
    _jac = jacrev(_grad, argnums=argnums, **kwargs)
    return _jac(*inputs)  # type: ignore
# [docs]
def hess_fn_rev(
    f: Callable[..., Tensor], argnums: tuple[int, ...] | int = 0
) -> Callable[..., Any]:
    """
    Return the Hessian function using reverse-mode autodiff twice.

    (Functorch's `hessian` uses forward and backward mode, but forward is
    not implemented for our custom autograd functions.)

    Parameters
    ----------
    f : Callable[..., Tensor]
        The function whose result is differentiated.
    argnums : int or tuple[int, ...], optional
        The variable(s) w.r.t. which will be differentiated. The same
        selection is used for both differentiation passes. Defaults to 0.

    Returns
    -------
    Callable
        A function that computes the Hessian of `f` with respect to the
        specified argument(s). It takes the same arguments as `f`; for a
        tuple of `argnums`, the result is a nested tuple of tensors.
    """
    # First reverse-mode pass: Jacobian of `f`.
    jac_fn = torch.func.jacrev(f, argnums=argnums)

    # Second reverse-mode pass: Jacobian of the Jacobian.
    return torch.func.jacrev(jac_fn, argnums=argnums)