# Source code for tad_mctc.convert.numpy

# This file is part of tad-mctc.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Conversion: numpy
=================

This module contains safe functions for numpy and pytorch interconversion.
"""

from __future__ import annotations

import numpy as np
import torch
from numpy.typing import DTypeLike, NDArray

from .._version import __tversion__
from ..typing import Any, Tensor, get_default_dtype

__all__ = ["numpy_to_tensor", "tensor_to_numpy"]


# Each pair couples a NumPy scalar type with the torch dtype of the same name.
# Keys are the canonical scalar type objects (``np.dtype(...).type``), so the
# table can be queried with ``array.dtype.type``.
numpy_to_torch_dtype_dict = {
    np.dtype(np_scalar).type: torch_dtype
    for np_scalar, torch_dtype in (
        (np.float16, torch.float16),
        (np.float32, torch.float32),
        (np.float64, torch.float64),
        (np.int8, torch.int8),
        (np.int16, torch.int16),
        (np.int32, torch.int32),
        (np.int64, torch.int64),
        (np.uint8, torch.uint8),
    )
}
"""Dict of NumPy dtype -> torch dtype (when the correspondence exists)"""

# Reverse lookup table: every torch dtype listed here maps to the NumPy dtype
# object built from the scalar type of the same name.
torch_to_numpy_dtype_dict: dict[torch.dtype, DTypeLike] = {
    torch_dtype: np.dtype(np_scalar)
    for torch_dtype, np_scalar in (
        (torch.float16, np.float16),
        (torch.float32, np.float32),
        (torch.float64, np.float64),
        (torch.int8, np.int8),
        (torch.int16, np.int16),
        (torch.int32, np.int32),
        (torch.int64, np.int64),
        (torch.uint8, np.uint8),
    )
}
"""Dict of torch dtype -> NumPy dtype conversion"""


def numpy_to_tensor(
    x: NDArray[Any],
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> Tensor:
    """
    Convert a numpy array to a PyTorch tensor.

    Parameters
    ----------
    x : NDArray[Any]
        Array to convert.
    device : :class:`torch.device` | None, optional
        Device to store the tensor on. Defaults to `None`.
    dtype : :class:`torch.dtype` | None, optional
        Data type of the tensor. If `None` (default), the torch dtype
        registered for ``x.dtype`` in ``numpy_to_torch_dtype_dict`` is used;
        array dtypes without an entry in that table fall back to
        ``get_default_dtype()``.

    Returns
    -------
    Tensor
        Converted PyTorch tensor.
    """
    if dtype is None:
        # Look up the torch counterpart of the array's scalar type; dtypes
        # missing from the table resolve to the default dtype instead.
        dtype = numpy_to_torch_dtype_dict.get(x.dtype.type, get_default_dtype())
        # Narrows the optional type for static checkers.
        assert dtype is not None

    return torch.from_numpy(x).type(dtype).to(device)
def tensor_to_numpy(x: Tensor, dtype: DTypeLike | None = None) -> NDArray[Any]:
    """
    Convert a PyTorch tensor to a numpy array.

    The tensor is detached and moved to the CPU before conversion. Tensors
    that functorch reports as grad-tracking are unwrapped and converted from
    their underlying storage (see the linked PyTorch issue in the body); all
    other tensors are converted with ``Tensor.numpy()``.

    Parameters
    ----------
    x : Tensor
        Tensor to convert.
    dtype : np.dtype, optional
        Data type of the array. If `None` (default), the NumPy dtype
        registered for ``x.dtype`` in ``torch_to_numpy_dtype_dict`` is used;
        tensor dtypes without an entry in that table fall back to
        ``np.dtype(np.float64)``.

    Returns
    -------
    np.ndarray
        Converted numpy array.
    """
    if dtype is None:
        dtype = torch_to_numpy_dtype_dict.get(x.dtype, np.dtype(np.float64))

    # NumPy dtype describing the tensor's own element type. It is only used
    # to reinterpret the raw storage bytes in the functorch branch below and
    # is `None` when `x.dtype` has no entry in the table.
    xdtype = torch_to_numpy_dtype_dict.get(x.dtype)

    x = x.detach().cpu()

    # pylint: disable=protected-access
    # see: https://github.com/pytorch/pytorch/issues/91810
    if __tversion__ >= (1, 13, 0):
        if torch._C._functorch.is_gradtrackingtensor(x):
            # Peel off every functorch wrapper level to reach the plain tensor.
            while torch._C._functorch.is_functorch_wrapped_tensor(x) is True:
                x = torch._C._functorch.get_unwrapped(x)

            # NOTE(review): both branches read the *whole* storage and reshape
            # it to `x.shape`; this presumably relies on the unwrapped tensor
            # being contiguous and spanning its entire storage -- confirm.
            if __tversion__ < (2, 0, 0):  # type: ignore[operator] # pragma: no cover
                interpreted = np.array(x.storage().tolist(), dtype=dtype)
            else:
                storage_bytes = bytes(x.untyped_storage())  # type: ignore
                # NOTE(review): `xdtype` may be `None` here for tensor dtypes
                # missing from the table -- verify that this is intended.
                interpreted = np.frombuffer(storage_bytes, dtype=xdtype).astype(
                    dtype
                )
            return interpreted.reshape(x.shape)

    _x: NDArray[Any] = x.numpy()
    return _x.astype(dtype)