# Source code for tad_mctc.molecule.property

# This file is part of tad-mctc.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Molecule: Property
==================

Collection of functions for the calculation of molecular properties.
"""

from __future__ import annotations

import torch

from .. import storch, units
from ..batch import eye
from ..math import einsum
from ..typing import Tensor

__all__ = [
    "inertia_moment",
    "center_of_mass",
    "positions_rel_com",
    "rot_consts",
]


def center_of_mass(masses: Tensor, positions: Tensor) -> Tensor:
    r"""
    Calculate the center of mass from the atomic coordinates and masses.

    .. math::

        R_x = \frac{\sum_z m_z r_{z,x}}{\sum_z m_z}

    Parameters
    ----------
    masses : Tensor
        Atomic masses for all atoms in the system (shape: ``(..., nat)``).
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).

    Returns
    -------
    Tensor
        Cartesian coordinates of the center of mass of shape ``(..., 3)``.
    """
    # Inverse of the total mass per system. ``storch.reciprocal`` is used
    # instead of a plain division (presumably to guard against a vanishing
    # total mass, e.g. in padded batches -- confirm in ``storch``).
    s = storch.reciprocal(torch.sum(masses, dim=-1))

    # Mass-weighted sum of positions, scaled by the inverse total mass.
    return einsum("...z,...zx,...->...x", masses, positions, s)


def positions_rel_com(masses: Tensor, positions: Tensor) -> Tensor:
    """
    Shift atomic positions so that the center of mass lies at the origin.

    Parameters
    ----------
    masses : Tensor
        Atomic masses for all atoms in the system (shape: ``(..., nat)``).
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).

    Returns
    -------
    Tensor
        Cartesian coordinates relative to the center of mass
        (shape: ``(..., nat, 3)``).
    """
    # The center has shape (..., 3); inserting an atom axis lets it
    # broadcast against every atom's coordinates.
    return positions - center_of_mass(masses, positions)[..., None, :]


def inertia_moment(
    masses: Tensor,
    positions: Tensor,
    center_pa: bool = True,
    pos_already_com: bool = False,
) -> Tensor:
    r"""
    Calculate the inertia tensor of the molecule.

    First, the mass-weighted second-moment tensor of the (center-of-mass
    shifted) positions is formed,

    .. math::

        M_{xy} = \sum_m m_m r_{m,x} r_{m,y} ,

    from which the inertia tensor follows as

    .. math::

        I = \mathrm{tr}(M) \, \mathbb{1} - M .

    Parameters
    ----------
    masses : Tensor
        Atomic masses for all atoms in the system (shape: ``(..., nat)``).
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
    center_pa : bool, optional
        If ``True``, return the inertia tensor :math:`I`, whose
        eigendecomposition yields the principal moments and principal axes
        used in rotational analysis. If ``False``, return the second-moment
        tensor :math:`M` instead. Defaults to ``True``.
    pos_already_com : bool, optional
        If ``True``, the positions are assumed to be centered at the center
        of mass already and are used as given. Defaults to ``False``.

    Returns
    -------
    Tensor
        Inertia tensor (or second-moment tensor if ``center_pa=False``) of
        shape ``(..., 3, 3)``.
    """
    if not pos_already_com:
        positions = positions_rel_com(masses, positions)

    # Mass-weighted second moments: M_xy = sum_m m_m * r_mx * r_my.
    im = einsum("...m,...mx,...my->...xy", masses, positions, positions)

    if not center_pa:
        return im

    # I = tr(M) * 1 - M. The identity is created with the dtype and device
    # of ``im`` so the subtraction never mixes precisions or devices.
    trace = torch.diagonal(im, dim1=-2, dim2=-1).sum(dim=-1)
    identity = torch.eye(im.shape[-1], dtype=im.dtype, device=im.device)
    return identity * trace[..., None, None] - im


# TODO: Check against reference values
# https://github.com/psi4/psi4/blob/3c2be0144a850eaad3b428ceabc58ff38a163fde/psi4/src/psi4/libmints/molecule.cc#L1353
# https://github.com/pyscf/pyscf/blob/master/pyscf/hessian/thermo.py#L111
def rot_consts(masses: Tensor, positions: Tensor) -> Tensor:  # pragma: no cover
    r"""
    Compute rotational constants from the principal moments of inertia.

    .. math::

        B = \frac{h}{8 \pi^2 c I} = \frac{\hbar}{4 \pi c I}

    Everything is evaluated in atomic units (:math:`\hbar = 1`), so the
    expression reduces to :math:`B = 1 / (4 \pi c I)`. Principal moments
    that do not exceed ``1e-6`` yield a rotational constant of zero.

    Parameters
    ----------
    masses : Tensor
        Atomic masses for all atoms in the system (shape: ``(..., nat)``).
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).

    Returns
    -------
    Tensor
        Rotational constants in atomic units, of shape ``(..., 3)``.
    """
    # The eigenvalues of the inertia tensor are the principal moments of
    # inertia; the eigenvectors (principal axes) are not needed here.
    inertia = inertia_moment(masses, positions, center_pa=True)
    moments, _ = storch.eighb(inertia)

    # Speed of light converted to atomic units of length per time.
    c_au = units.CODATA.c * (units.METER2AU / units.SECOND2AU)

    # B = hbar / (4 pi c I) with hbar = 1.
    consts = storch.reciprocal(4 * torch.pi * c_au * moments)

    # Suppress constants belonging to (numerically) vanishing moments.
    significant = moments > 1e-6
    return torch.where(significant, consts, torch.zeros_like(consts))