# Source code for tad_mctc.io.checks.shape

# This file is part of tad-mctc.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
I/O Checks: Shape
=================

This module contains shape checkers for the inputs passed to the reader/writer.
"""

from __future__ import annotations

from ...typing import Tensor

__all__ = ["shape_checks"]


def shape_checks(
    numbers: Tensor, positions: Tensor, allow_batched: bool = True
) -> bool:
    """
    Check the shapes of the numbers and positions tensors.

    The leading dimensions of ``positions`` must match the shape of
    ``numbers`` and the last dimension of ``positions`` must be of size 3.
    Batched tensors are accepted by default; pass ``allow_batched=False`` to
    additionally require a 1D ``numbers`` and a 2D ``positions`` tensor.

    Parameters
    ----------
    numbers : Tensor
        A tensor of shape ``(..., nat)`` containing atomic numbers or symbols.
    positions : Tensor
        A tensor of shape ``(..., nat, 3)`` containing atomic positions.
    allow_batched : bool, optional
        If ``False``, tensors with batch dimensions are rejected, i.e.,
        ``numbers`` must be 1D and ``positions`` must be 2D.
        Defaults to ``True``.

    Returns
    -------
    bool
        True if the shapes are correct.

    Raises
    ------
    ValueError
        If the positions tensor has no dimensions, the shapes of both
        tensors are inconsistent, the last dimension of the positions tensor
        is not 3 (cartesian directions), or, with ``allow_batched=False``,
        the numbers tensor has not one dimension or the positions tensor has
        not two dimensions.
    """
    # A 0-dimensional positions tensor has no last axis; report it as a shape
    # error instead of letting ``shape[-1]`` fail with an IndexError below.
    if len(positions.shape) == 0:
        raise ValueError(
            "The position tensor must have at least one dimension "
            "(but is a scalar)."
        )

    # All leading dimensions (batch and atom axes) must agree.
    if numbers.shape != positions.shape[:-1]:
        raise ValueError(
            f"Shape of positions ({positions.shape[:-1]}) is not consistent "
            f"with atomic numbers ({numbers.shape})."
        )

    if positions.shape[-1] != 3:
        raise ValueError(
            "The last dimension of the position tensor must present the "
            "cartesian directions, i.e., it must be size 3 (but is "
            f"{positions.shape[-1]})."
        )

    # Explicit ``is False`` keeps the original semantics: only the literal
    # ``False`` switches on the strict non-batched check.
    if allow_batched is False:
        if len(numbers.shape) != 1 or len(positions.shape) != 2:
            raise ValueError(
                "Invalid shape for tensors (batched tensors not allowed)."
            )

    return True