# Source code for tad_mctc.io.read.frompath

# This file is part of tad-mctc.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
I/O Read: From Path
===================

A convenience function to create readers that take a path instead of a stream.

Example
-------
>>> from tad_mctc.io import read
>>> read_xyz = read.create_path_reader(read.read_xyz_fileobj)
>>> path = ...
>>> numbers, positions = read_xyz(path)
"""

from __future__ import annotations

from pathlib import Path
from typing import runtime_checkable

import torch

from ...typing import IO, Any, Literal, PathLike, Protocol, Tensor

__all__ = ["create_path_reader", "create_path_reader_dotfiles"]


@runtime_checkable
class ReaderFunction(Protocol):
    """
    Type annotation for a reader function.

    A conforming callable takes an open file object together with an optional
    device and dtype, and returns a single tensor or a pair of tensors.
    """

    def __call__(
        self,
        fileobj: IO[Any],
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> Tensor | tuple[Tensor, Tensor]:
        """
        Reads from an already opened file object.

        Parameters
        ----------
        fileobj : IO[Any]
            Open file object to read from.
        device : :class:`torch.device` | None, optional
            Device to store the tensor on. Defaults to ``None``.
        dtype : :class:`torch.dtype` | None, optional
            Floating point data type of the tensor. Defaults to ``None``.

        Returns
        -------
        Tensor | tuple[Tensor, Tensor]
            Returned tensor or tensors.
        """


@runtime_checkable
class FileReaderFunction(Protocol):
    """
    Type annotation for a file reader function.

    A conforming callable takes a file path (instead of an open file object)
    and returns a single tensor or a pair of tensors.
    """

    def __call__(
        self,
        filepath: PathLike,
        mode: str = "r",
        encoding: str = "utf-8",
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        **kwargs: Any,
    ) -> Tensor | tuple[Tensor, Tensor]:
        """
        Reads the file from the specified path.

        Parameters
        ----------
        filepath : PathLike
            Path of file containing the structure.
        mode : str, optional
            Mode in which the file is opened. Defaults to ``"r"``.
        encoding : str, optional
            Encoding for file. Defaults to ``"utf-8"``.
        device : :class:`torch.device` | None, optional
            Device to store the tensor on. Defaults to ``None``.
        dtype : :class:`torch.dtype` | None, optional
            Floating point data type of the tensor. Defaults to ``None``.
        **kwargs : Any
            Additional keyword arguments accepted by the reader.

        Returns
        -------
        Tensor | tuple[Tensor, Tensor]
            Returned tensor or tensors.

        Raises
        ------
        FileNotFoundError
            The file specified in ``filepath`` cannot be found.
        """


def create_path_reader(reader_function: ReaderFunction) -> FileReaderFunction:
    """
    Creates a function that reads data from a specified file path using a
    given reader function.

    The returned function opens the file at the given path and passes the open
    file object on to ``reader_function``.

    Parameters
    ----------
    reader_function : ReaderFunction
        The function used to read and process the file contents.

    Returns
    -------
    FileReaderFunction
        A function that takes a file path, mode, encoding, device, and dtype
        (plus optional keyword arguments), and returns the processed data.
    """

    def read_from_path(
        filepath: PathLike,
        mode: str = "r",
        encoding: str = "utf-8",
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        **kwargs: Any,
    ) -> Tensor | tuple[Tensor, Tensor]:
        """
        Reads the file from the specified path.

        Parameters
        ----------
        filepath : PathLike
            Path of file containing the structure.
        mode : str, optional
            Mode in which the file is opened. Defaults to ``"r"``.
        encoding : str, optional
            Encoding for file. Only used for text modes; it is ignored if
            ``mode`` is a binary mode. Defaults to ``"utf-8"``.
        device : :class:`torch.device` | None, optional
            Device to store the tensor on. Defaults to ``None``.
        dtype : :class:`torch.dtype` | None, optional
            Floating point data type of the tensor. Defaults to ``None``.
        **kwargs : Any
            Additional keyword arguments forwarded unchanged to
            ``reader_function``.

        Returns
        -------
        Tensor | tuple[Tensor, Tensor]
            Returned tensor or tensors.

        Raises
        ------
        FileNotFoundError
            The file specified in ``filepath`` cannot be found.
        """
        path = Path(filepath)

        # Check if the file exists
        if not path.exists():
            raise FileNotFoundError(f"The file '{path}' does not exist.")

        # ``open`` raises a ``ValueError`` if an encoding is given together
        # with a binary mode, so the encoding is only passed for text modes.
        open_kwargs: dict[str, Any] = (
            {} if "b" in mode else {"encoding": encoding}
        )

        with open(path, mode=mode, **open_kwargs) as fileobj:
            # device and dtype are passed positionally, so the reader is free
            # to name these parameters differently
            return reader_function(fileobj, device, dtype, **kwargs)

    return read_from_path
################################################################################


@runtime_checkable
class ReaderFunctionTensor(Protocol):
    """
    Type annotation for a reader function that returns a tensor.

    A conforming callable takes an open file object together with an optional
    device and dtype, and returns exactly one tensor.
    """

    def __call__(
        self,
        fileobj: IO[Any],
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> Tensor:
        """
        Reads from an already opened file object.

        Parameters
        ----------
        fileobj : IO[Any]
            Open file object to read from.
        device : :class:`torch.device` | None, optional
            Device to store the tensor on. Defaults to ``None``.
        dtype : :class:`torch.dtype` | None, optional
            Floating point data type of the tensor. Defaults to ``None``.

        Returns
        -------
        Tensor
            Value stored in the file as tensor.
        """


@runtime_checkable
class FileReaderFunctionTensor(Protocol):
    """
    Type annotation for a file reader function that returns a tensor.

    A conforming callable takes a file path (instead of an open file object)
    and returns exactly one tensor.
    """

    def __call__(
        self,
        filepath: PathLike,
        mode: str = "r",
        encoding: str = "utf-8",
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> Tensor:
        """
        Reads the file from the specified path.

        Parameters
        ----------
        filepath : PathLike
            Path of file containing the structure.
        mode : str, optional
            Mode in which the file is opened. Defaults to ``"r"``.
        encoding : str, optional
            Encoding for file. Defaults to ``"utf-8"``.
        device : :class:`torch.device` | None, optional
            Device to store the tensor on. Defaults to ``None``.
        dtype : :class:`torch.dtype` | None, optional
            Floating point data type of the tensor. Defaults to ``None``.

        Returns
        -------
        Tensor
            Value stored in the file as tensor.
        """
def create_path_reader_dotfiles(
    reader_function: ReaderFunctionTensor, name: Literal[".CHRG", ".UHF"]
) -> FileReaderFunctionTensor:
    """
    Creates a function that reads a dotfile from a specified path using a
    given reader function.

    The returned function accepts the dotfile itself, another file in the same
    directory (e.g., a coordinate file), or the directory containing the
    dotfile. If no dotfile is found, a zero tensor is returned instead of
    raising an error.

    Parameters
    ----------
    reader_function : ReaderFunctionTensor
        The function used to read and process the file contents.
    name : Literal[".CHRG", ".UHF"]
        Name of the dotfile to be read.

    Returns
    -------
    FileReaderFunctionTensor
        A function that takes a file path, mode, encoding, device, and dtype,
        and returns the processed data.
    """
    # return default if file is not found (must be integer to allow integer
    # dtypes from PyTorch, e.g., 0.0 fails with torch.long)
    default_value = 0

    def read_from_path(
        filepath: PathLike,
        mode: str = "r",
        encoding: str = "utf-8",
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> Tensor:
        """
        Reads the dotfile belonging to the specified path.

        Parameters
        ----------
        filepath : PathLike
            Path of the dotfile, of a file in the same directory as the
            dotfile, or of the directory containing the dotfile.
        mode : str, optional
            Mode in which the file is opened. Defaults to ``"r"``.
        encoding : str, optional
            Encoding for file. Only used for text modes; it is ignored if
            ``mode`` is a binary mode. Defaults to ``"utf-8"``.
        device : :class:`torch.device` | None, optional
            Device to store the tensor on. Defaults to ``None``.
        dtype : :class:`torch.dtype` | None, optional
            Floating point data type of the tensor. Defaults to ``None``.

        Returns
        -------
        Tensor
            Value stored in the file as tensor, or a zero tensor if the
            dotfile does not exist.
        """
        path = Path(filepath)

        # possibly coordinate file given -> search dotfile in same directory
        # (an explicitly given ".CHRG" or ".UHF" file is used as is)
        if path.is_file():
            if path.name not in (".CHRG", ".UHF"):
                path = path.parent / name

        if path.is_dir():
            path = path / name

        # Check if the file now exists
        if not path.exists():
            return torch.tensor(default_value, device=device, dtype=dtype)

        # ``open`` raises a ``ValueError`` if an encoding is given together
        # with a binary mode, so the encoding is only passed for text modes.
        open_kwargs: dict[str, str] = (
            {} if "b" in mode else {"encoding": encoding}
        )

        with open(path, mode=mode, **open_kwargs) as fileobj:
            return reader_function(fileobj, device, dtype)

    return read_from_path