Source code for tad_mctc.molecule.bond

# This file is part of tad-mctc.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Molecule: Bonds
===============

Module for estimating or guessing bond orders between atoms. This module uses
a geometric model to obtain the optimal bond distance between a pair of atoms,
which is compared with the actual distance between the atoms to obtain the
likelihood of the bond being present.

Based on S. Spicher and S. Grimme, *Angew. Chem. Int. Ed.*, **2020**, 59,
15665–15673 (`DOI <https://doi.org/10.1002/anie.202004239>`__).

Example
-------
>>> import torch
>>> from tad_mctc.molecule.bond import guess_bond_order
>>> numbers = torch.tensor([7, 7, 1, 1, 1, 1, 1, 1])
>>> positions = torch.tensor([
...     [-2.98334550857544, -0.08808205276728, +0.00000000000000],
...     [+2.98334550857544, +0.08808205276728, +0.00000000000000],
...     [-4.07920360565186, +0.25775116682053, +1.52985656261444],
...     [-1.60526800155640, +1.24380481243134, +0.00000000000000],
...     [-4.07920360565186, +0.25775116682053, -1.52985656261444],
...     [+4.07920360565186, -0.25775116682053, -1.52985656261444],
...     [+1.60526800155640, -1.24380481243134, +0.00000000000000],
...     [+4.07920360565186, -0.25775116682053, +1.52985656261444],
... ])
>>> cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
>>> bond_order = guess_bond_order(numbers, positions, cn)
>>> print(bond_order)
tensor([[0.0000, 0.0000, 0.4403, 0.4334, 0.4403, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4403, 0.4334, 0.4403],
        [0.4403, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4334, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4403, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4403, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4334, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4403, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
>>> print(bond_order > 0.3)
tensor([[False, False,  True,  True,  True, False, False, False],
        [False, False, False, False, False,  True,  True,  True],
        [ True, False, False, False, False, False, False, False],
        [ True, False, False, False, False, False, False, False],
        [ True, False, False, False, False, False, False, False],
        [False,  True, False, False, False, False, False, False],
        [False,  True, False, False, False, False, False, False],
        [False,  True, False, False, False, False, False, False]])
"""

from __future__ import annotations

import torch

from .. import storch
from ..batch import real_pairs
from ..ncoord import defaults, erf_count
from ..typing import DD, Any, CountingFunction, Tensor

__all__ = ["guess_bond_length", "guess_bond_order"]


# Per-element electronegativity parameters, looked up directly by atomic
# number (``_en[numbers]`` in ``guess_bond_length``). Index 0 is a zero-valued
# dummy entry, so the table holds 87 values for atomic numbers 0..86.
# NOTE(review): indices 58-71 all carry the constant 2.0 -- presumably a
# placeholder for the lanthanides; confirm against the reference parameters.
_en = torch.tensor(
    [
        *[+0.00000000],
        *[+2.30085633, +2.78445145, +1.52956084, +1.51714704, +2.20568300],
        *[+2.49640820, +2.81007174, +4.51078438, +4.67476223, +3.29383610],
        *[+2.84505365, +2.20047950, +2.31739628, +2.03636974, +1.97558064],
        *[+2.13446570, +2.91638164, +1.54098156, +2.91656301, +2.26312147],
        *[+2.25621439, +1.32628677, +2.27050569, +1.86790977, +2.44759456],
        *[+2.49480042, +2.91545568, +3.25897750, +2.68723778, +1.86132251],
        *[+2.01200832, +1.97030722, +1.95495427, +2.68920990, +2.84503857],
        *[+2.61591858, +2.64188286, +2.28442252, +1.33011187, +1.19809388],
        *[+1.89181390, +2.40186898, +1.89282464, +3.09963488, +2.50677823],
        *[+2.61196704, +2.09943450, +2.66930105, +1.78349472, +2.09634533],
        *[+2.00028974, +1.99869908, +2.59072029, +2.54497829, +2.52387890],
        *[+2.30204667, +1.60119300, +2.00000000, +2.00000000, +2.00000000],
        *[+2.00000000, +2.00000000, +2.00000000, +2.00000000, +2.00000000],
        *[+2.00000000, +2.00000000, +2.00000000, +2.00000000, +2.00000000],
        *[+2.00000000, +2.30089349, +1.75039077, +1.51785130, +2.62972945],
        *[+2.75372921, +2.62540906, +2.55860939, +3.32492356, +2.65140898],
        *[+1.52014458, +2.54984804, +1.72021963, +2.69303422, +1.81031095],
        *[+2.34224386],
    ]
)
"""
Electronegativity parameter used to determine polarity of bonds.
"""

# Per-element base radius, looked up directly by atomic number
# (``_r0[numbers]``). It is the coordination-number independent term of the
# atomic radius ``r0 + cf * cn`` built in ``guess_bond_length``. Index 0 is a
# zero-valued dummy entry, so the table holds 87 values for atomic numbers
# 0..86.
# NOTE(review): the description below calls these Van-der-Waals radii, but
# they are summed pairwise to form bond-length estimates -- presumably
# covalent-type radii; confirm against the reference.
# NOTE(review): indices 58-71 all carry the constant 2.8 -- presumably a
# placeholder for the lanthanides; confirm against the reference parameters.
_r0 = torch.tensor(
    [
        *[+0.00000000],
        *[+0.55682207, +0.80966997, +2.49092101, +1.91705642, +1.35974851],
        *[+0.98310699, +0.98423007, +0.76716063, +1.06139799, +1.17736822],
        *[+2.85570926, +2.56149012, +2.31673425, +2.03181740, +1.82568535],
        *[+1.73685958, +1.97498207, +2.00136196, +3.58772537, +2.68096221],
        *[+2.23355957, +2.33135502, +2.15870365, +2.10522128, +2.16376162],
        *[+2.10804037, +1.96460045, +2.00476257, +2.22628712, +2.43846700],
        *[+2.39408483, +2.24245792, +2.05751204, +2.15427677, +2.27191920],
        *[+2.19722638, +3.80910350, +3.26020971, +2.99716916, +2.71707818],
        *[+2.34950167, +2.11644818, +2.47180659, +2.32198800, +2.32809515],
        *[+2.15244869, +2.55958313, +2.59141300, +2.62030465, +2.39935278],
        *[+2.56912355, +2.54374096, +2.56914830, +2.53680807, +4.24537037],
        *[+3.66542289, +3.19903011, +2.80000000, +2.80000000, +2.80000000],
        *[+2.80000000, +2.80000000, +2.80000000, +2.80000000, +2.80000000],
        *[+2.80000000, +2.80000000, +2.80000000, +2.80000000, +2.80000000],
        *[+2.80000000, +2.34880037, +2.37597108, +2.49067697, +2.14100577],
        *[+2.33473532, +2.19498900, +2.12678348, +2.34895048, +2.33422774],
        *[+2.86560827, +2.62488837, +2.88376127, +2.75174124, +2.83054552],
        *[+2.63264944],
    ]
)
"""
Van-der-Waals radii for each element.
"""

# Per-element factor that multiplies the coordination number in the atomic
# radius ``r0 + cf * cn`` (see ``guess_bond_length``), looked up directly by
# atomic number (``_cf[numbers]``). Index 0 is a zero-valued dummy entry, so
# the table holds 87 values for atomic numbers 0..86. Some entries are
# negative, i.e. the radius shrinks with growing coordination number.
# NOTE(review): indices 58-71 all carry the constant 0.077 -- presumably a
# placeholder for the lanthanides; confirm against the reference parameters.
_cf = torch.tensor(
    [
        *[+0.00000000],
        *[+0.17957827, +0.25584045, -0.02485871, +0.00374217, +0.05646607],
        *[+0.10514203, +0.09753494, +0.30470380, +0.23261783, +0.36752208],
        *[+0.00131819, -0.00368122, -0.01364510, +0.04265789, +0.07583916],
        *[+0.08973207, -0.00589677, +0.13689929, -0.01861307, +0.11061699],
        *[+0.10201137, +0.05426229, +0.06014681, +0.05667719, +0.02992924],
        *[+0.03764312, +0.06140790, +0.08563465, +0.03707679, +0.03053526],
        *[-0.00843454, +0.01887497, +0.06876354, +0.01370795, -0.01129196],
        *[+0.07226529, +0.01005367, +0.01541506, +0.05301365, +0.07066571],
        *[+0.07637611, +0.07873977, +0.02997732, +0.04745400, +0.04582912],
        *[+0.10557321, +0.02167468, +0.05463616, +0.05370913, +0.05985441],
        *[+0.02793994, +0.02922983, +0.02220438, +0.03340460, -0.04110969],
        *[-0.01987240, +0.07260201, +0.07700000, +0.07700000, +0.07700000],
        *[+0.07700000, +0.07700000, +0.07700000, +0.07700000, +0.07700000],
        *[+0.07700000, +0.07700000, +0.07700000, +0.07700000, +0.07700000],
        *[+0.07700000, +0.08379100, +0.07314553, +0.05318438, +0.06799334],
        *[+0.04671159, +0.06758819, +0.09488437, +0.07556405, +0.13384502],
        *[+0.03203572, +0.04235009, +0.03153769, -0.00152488, +0.02714675],
        *[+0.04800662],
    ]
)
"""
Coordination number based scaling factor.
"""

# Row (period) of the periodic table for every atomic number 0..86. The
# second tensor gives how many consecutive atomic numbers fall into rows
# 0..6; row 0 only contains the dummy entry for atomic number 0.
_ir = torch.repeat_interleave(
    torch.arange(7),
    torch.tensor([1, 2, 8, 8, 18, 18, 32]),
)
"""
Row index in the periodic table
"""

# One coefficient per periodic-table row (entry 0 belongs to the dummy row),
# expanded to one value per atomic number through the row index ``_ir`` and
# scaled by 0.01.
_p1 = (
    torch.tensor(
        [0.0, 29.84522887, -1.70549806, 6.54013762, 6.39169003, 6.00, 5.6],
    )
    .index_select(0, _ir)
    .mul(0.01)
)
"""
Polynomial parameters for contributions linear in the electronegativity
difference.
"""

# One coefficient per periodic-table row (entry 0 belongs to the dummy row),
# expanded to one value per atomic number through the row index ``_ir`` and
# scaled by 0.01.
_p2 = (
    torch.tensor(
        [0.0, -8.87843763, 2.10878369, 0.08009374, -0.85808076, -1.15, -1.3],
    )
    .index_select(0, _ir)
    .mul(0.01)
)
"""
Polynomial parameters for contributions quadratic in the electronegativity
difference.
"""


def guess_bond_length(numbers: Tensor, cn: Tensor) -> Tensor:
    """
    Estimate the equilibrium bond length of every atom pair.

    Each atom is assigned a coordination-number dependent radius
    ``r0 + cf * cn``. The sum of the two radii of a pair is then scaled by a
    polynomial in the absolute electronegativity difference of the pair; the
    linear and quadratic coefficients are the averages of the row-dependent
    parameters of both atoms.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers for all atoms in the system. They are used as indices
        into the element parameter tables.
    cn : Tensor
        Coordination numbers for all atoms in the system. The parameter
        tables are converted to the device and dtype of this tensor.

    Returns
    -------
    Tensor
        Estimated bond lengths for all atom pairs. Compared to ``numbers``,
        the result has one additional trailing dimension of the same size.

    Example
    -------
    >>> import torch
    >>> from tad_mctc.molecule.bond import guess_bond_length
    >>> numbers = torch.tensor([6, 8, 7, 1, 1, 1])
    >>> cn = torch.tensor([
    ...     3.0059586, 1.0318390, 3.0268824, 1.0061584, 1.0036336, 0.9989871
    ... ])
    >>> print(guess_bond_length(numbers, cn))
    tensor([[2.5983, 2.2588, 2.5871, 1.9833, 1.9828, 1.9820],
            [2.2588, 2.1631, 2.2855, 1.5542, 1.5538, 1.5531],
            [2.5871, 2.2855, 2.5589, 1.8902, 1.8897, 1.8890],
            [1.9833, 1.5542, 1.8902, 1.4750, 1.4746, 1.4737],
            [1.9828, 1.5538, 1.8897, 1.4746, 1.4741, 1.4733],
            [1.9820, 1.5531, 1.8890, 1.4737, 1.4733, 1.4724]])
    """
    dd: DD = {"device": cn.device, "dtype": cn.dtype}

    def _per_atom(table: Tensor) -> Tensor:
        # Move the element table next to ``cn`` and pick one value per atom.
        return table.to(**dd)[numbers]

    base_radius = _per_atom(_r0)
    electroneg = _per_atom(_en)
    cn_factor = _per_atom(_cf)
    linear = _per_atom(_p1)
    quadratic = _per_atom(_p2)

    # Coordination-number dependent radius of every atom.
    radius = base_radius + cn_factor * cn

    # Pairwise quantities are built by broadcasting a column against a row.
    en_diff = (electroneg.unsqueeze(-1) - electroneg.unsqueeze(-2)).abs()
    linear_pair = (linear.unsqueeze(-1) + linear.unsqueeze(-2)) / 2
    quadratic_pair = (quadratic.unsqueeze(-1) + quadratic.unsqueeze(-2)) / 2

    # Polarity correction: 1 - k1 * |dEN| - k2 * |dEN|^2.
    scale = (
        torch.ones_like(en_diff)
        - linear_pair * en_diff
        - quadratic_pair * en_diff**2
    )

    radius_sum = radius.unsqueeze(-1) + radius.unsqueeze(-2)
    return scale * radius_sum
def guess_bond_order(
    numbers: Tensor,
    positions: Tensor,
    cn: Tensor,
    counting_function: CountingFunction = erf_count,
    **kwargs: Any,
) -> Tensor:
    """
    Guess whether atom pairs are bonded using a geometric criterion.

    The actual interatomic distance of each pair is compared with the bond
    length estimated by :func:`guess_bond_length`, which takes into account
    the coordination number of each atom as well as the polarity of the bond.
    The comparison itself is delegated to the counting function.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers for all atoms in the system of shape ``(..., nat)``.
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
    cn : Tensor
        Coordination numbers for all atoms (shape: ``(..., nat)``).
    counting_function : CountingFunction
        Function to determine whether two atoms are bonded. It is called with
        the pair distances, the estimated bond lengths and all additional
        keyword arguments.
    kcn : float, optional
        Steepness of the counting function. If not given,
        ``defaults.KCN_EEQ`` is passed to the counting function.

    Returns
    -------
    Tensor
        Bond order for all atom pairs. Pairs that are masked out by
        ``real_pairs`` are set to zero.

    Example
    -------
    >>> import torch
    >>> from tad_mctc.molecule.bond import guess_bond_order
    >>> from tad_mctc.batch import pack
    >>> numbers = pack((
    ...     torch.tensor([7, 1, 1, 1]),
    ...     torch.tensor([6, 8, 8, 1, 1]),
    ... ))
    >>> positions = pack((
    ...     torch.tensor([
    ...         [+0.00000000000000, +0.00000000000000, -0.54524837997150],
    ...         [-0.88451840382282, +1.53203081565085, +0.18174945999050],
    ...         [-0.88451840382282, -1.53203081565085, +0.18174945999050],
    ...         [+1.76903680764564, +0.00000000000000, +0.18174945999050],
    ...     ]),
    ...     torch.tensor([
    ...         [-0.53424386915034, -0.55717948166537, +0.00000000000000],
    ...         [+0.21336223456096, +1.81136801357186, +0.00000000000000],
    ...         [+0.82345103924195, -2.42214694643037, +0.00000000000000],
    ...         [-2.59516465056138, -0.70672678063558, +0.00000000000000],
    ...         [+2.09259524590881, +1.87468519515944, +0.00000000000000],
    ...     ]),
    ... ))
    >>> cn = torch.tensor([
    ...     [2.9901006, 0.9977214, 0.9977214, 0.9977214, 0.0000000],
    ...     [3.0093639, 2.0046251, 1.0187057, 0.9978270, 1.0069743],
    ... ])
    >>> bond_order = guess_bond_order(numbers, positions, cn)
    >>> print(bond_order[0, ...])
    tensor([[0.0000, 0.4392, 0.4392, 0.4392, 0.0000],
            [0.4392, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.4392, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.4392, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
    >>> print(bond_order[1, ...])
    tensor([[0.0000, 0.5935, 0.4043, 0.3262, 0.0000],
            [0.5935, 0.0000, 0.0000, 0.0000, 0.3347],
            [0.4043, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.3262, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.3347, 0.0000, 0.0000, 0.0000]])
    """
    dd: DD = {"device": positions.device, "dtype": positions.dtype}

    pair_mask = real_pairs(numbers, True)

    # Masked-out pairs get a tiny, non-zero distance (machine epsilon)
    # instead of their actual value before entering the counting function.
    pair_dist = storch.cdist(positions, positions, p=2)
    tiny = torch.tensor(torch.finfo(positions.dtype).eps, **dd)
    distances = torch.where(pair_mask, pair_dist, tiny)

    # Only fill in the steepness if the caller did not choose one.
    kwargs.setdefault("kcn", defaults.KCN_EEQ)

    reference = guess_bond_length(numbers, cn)
    counted = counting_function(distances, reference, **kwargs)

    # Masked-out pairs carry no bond order.
    return torch.where(pair_mask, counted, torch.tensor(0.0, **dd))