Source code for tad_mctc.ncoord.d3

# This file is part of tad-mctc.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Coordination number: DFT-D3
===========================

Calculation of coordination number for DFT-D3.
"""

from __future__ import annotations

import torch

from .. import storch
from ..batch import real_pairs
from ..data import radii
from ..typing import DD, Any, CountingFunction, Tensor
from . import defaults
from .common import coordination_number
from .count import dexp_count, exp_count

__all__ = ["cn_d3", "cn_d3_gradient"]


def cn_d3(
    numbers: Tensor,
    positions: Tensor,
    counting_function: CountingFunction = exp_count,
    *,
    rcov: Tensor | None = None,
    cutoff: Tensor | None = None,
    **kwargs: Any,
) -> Tensor:
    """
    Compute the D3 fractional coordination number using the default
    parameters described in :mod:`tad_mctc.ncoord.defaults`.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers for all atoms in the system of shape ``(..., nat)``.
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
    counting_function : CountingFunction, optional
        Counting function used for the DFT-D3 coordination number.
        Defaults to the exponential counting function
        :func:`tad_mctc.ncoord.count.exp_count`.
    rcov : Tensor | None, optional
        Covalent radii for each atom (same shape as ``numbers``). Defaults to
        ``None``, in which case the D3 covalent radii (``radii.COV_D3``) are
        looked up via ``numbers``. A supplied tensor is cast to the device
        and dtype of ``positions``.
    cutoff : Tensor | None, optional
        Real-space cutoff. Defaults to ``None``, in which case
        :data:`tad_mctc.ncoord.defaults.CUTOFF_D3` is used.
    kwargs : dict[str, Any]
        Pass-through arguments forwarded to :func:`coordination_number`,
        which presumably hands them on to the counting function (e.g.
        ``kcn``, the steepness of the counting function).

    Returns
    -------
    Tensor
        Coordination numbers for all atoms (shape: ``(..., nat)``).
    """
    dd: DD = {"device": positions.device, "dtype": positions.dtype}

    # Fall back to the D3 defaults; accepting overrides keeps this function
    # parameter-compatible with `cn_d3_gradient`.
    if cutoff is None:
        cutoff = torch.tensor(defaults.CUTOFF_D3, **dd)

    if rcov is None:
        rcov = radii.COV_D3(**dd)[numbers]
    else:
        rcov = rcov.to(**dd)

    return coordination_number(
        numbers,
        positions,
        counting_function=counting_function,
        rcov=rcov,
        cutoff=cutoff,
        **kwargs,
    )
def cn_d3_gradient(
    numbers: Tensor,
    positions: Tensor,
    *,
    dcounting_function: CountingFunction = dexp_count,
    rcov: Tensor | None = None,
    cutoff: Tensor | None = None,
    **kwargs: Any,
) -> Tensor:
    """
    Compute the derivative of the fractional coordination number with
    respect to atomic positions.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers for all atoms in the system of shape ``(..., nat)``.
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
    dcounting_function : CountingFunction, optional
        Derivative of the counting function. Defaults to
        :func:`tad_mctc.ncoord.count.dexp_count`.
    rcov : Tensor | None, optional
        Covalent radii for each atom; must have the same shape as
        ``numbers``. Defaults to ``None``, in which case the D3 covalent
        radii (``radii.COV_D3``) are looked up via ``numbers``.
    cutoff : Tensor | None, optional
        Real-space cutoff; pairs farther apart contribute zero. Defaults to
        ``None``, in which case :data:`tad_mctc.ncoord.defaults.CUTOFF_D3`
        is used. A supplied value is converted to the device and dtype of
        ``positions``.
    kwargs : dict[str, Any]
        Pass-through arguments for counting function. For example, ``kcn``,
        the steepness of the counting function, which defaults to
        :data:`tad_mctc.ncoord.defaults.KCN_D3`.

    Returns
    -------
    Tensor
        Pairwise derivative of the coordination number with respect to the
        atomic positions (shape: ``(..., nat, nat, 3)``). Element
        ``[..., i, j, :]`` equals ``dcf_ij / r_ij * (R_j - R_i)``, where
        ``dcf_ij`` is the counting-function derivative (zero for masked
        pairs and pairs beyond ``cutoff``).

    Raises
    ------
    ValueError
        If shape mismatch between ``numbers``, ``positions`` and ``rcov`` is
        detected.
    """
    dd: DD = {"device": positions.device, "dtype": positions.dtype}

    if cutoff is None:
        cutoff = torch.tensor(defaults.CUTOFF_D3, **dd)
    else:
        # Mirror the treatment of `rcov`; `as_tensor` also accepts floats.
        cutoff = torch.as_tensor(cutoff, **dd)

    if rcov is None:
        rcov = radii.COV_D3(**dd)[numbers]
    else:
        rcov = rcov.to(**dd)

    if numbers.shape != rcov.shape:
        raise ValueError(
            f"Shape of covalent radii ({rcov.shape}) is not consistent with "
            f"atomic numbers ({numbers.shape})."
        )
    if numbers.shape != positions.shape[:-1]:
        raise ValueError(
            f"Shape of positions ({positions.shape[:-1]}) is not consistent "
            f"with atomic numbers ({numbers.shape})."
        )

    eps = torch.tensor(torch.finfo(positions.dtype).eps, **dd)

    # Pair mask from `real_pairs`, with the diagonal (i == j) excluded.
    mask = real_pairs(numbers, mask_diagonal=True)

    # Masked pairs get a tiny finite distance instead of 0 so that the
    # division by `distances` below never produces 0/0; their `dcf` is
    # zeroed anyway.
    distances = torch.where(mask, storch.cdist(positions, positions, p=2), eps)

    # Pairwise sum of covalent radii, shape (..., nat, nat).
    rc = rcov.unsqueeze(-2) + rcov.unsqueeze(-1)

    # Counting-function derivative, kept only for valid pairs within cutoff.
    dcf = torch.where(
        mask & (distances <= cutoff),
        dcounting_function(distances, rc, **kwargs),
        torch.tensor(0.0, **dd),
    )

    # (..., nat, nat, 3); rij[..., i, j, :] = positions[..., j, :] - positions[..., i, :]
    rij = positions.unsqueeze(-3) - positions.unsqueeze(-2)

    # (..., nat, nat, 1) * (..., nat, nat, 3)  -- "...ij,...ijx->...ijx"
    return (dcf / distances).unsqueeze(-1) * rij