# Source code for tad_mctc.ncoord.common

# This file is part of tad-mctc.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Coordination number: Common helpers
===================================

Generic building blocks that mirror the container-based design of mctc-lib.
"""

from __future__ import annotations

import torch

from .. import storch
from ..batch import real_pairs
from ..data import radii
from ..typing import DD, Any, CountingFunction, Tensor
from . import defaults

__all__ = ["coordination_number", "cut_coordination_number"]


def coordination_number(
    numbers: Tensor,
    positions: Tensor,
    *,
    counting_function: CountingFunction,
    rcov: Tensor | None = None,
    cutoff: Tensor | float | int | None = None,
    cn_max: Tensor | float | int | None = None,
    pair_weight: Tensor | None = None,
    **kwargs: Any,
) -> Tensor:
    """
    Evaluate coordination numbers with an arbitrary pair counting function.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
    positions : Tensor
        Cartesian coordinates (shape: ``(..., nat, 3)``). Device and dtype
        of this tensor are used for all intermediate tensors.
    counting_function : CountingFunction
        Pair counting function (exp, erf, dexp, ...). It is called as
        ``counting_function(distances, rc, **kwargs)``, where ``rc`` holds
        the pairwise sums of the covalent radii.
    rcov : Tensor | None, optional
        Covalent radii; must have the same shape as ``numbers``. If
        ``None``, the ``radii.COV_D3`` table is indexed with ``numbers``.
    cutoff : Tensor | float | int | None, optional
        Real-space cutoff. Pairs whose distance is larger than the cutoff
        do not contribute. ``None`` applies no distance cutoff.
    cn_max : Tensor | float | int | None, optional
        Optional upper bound for the coordination number. If given, the
        result is passed through :func:`cut_coordination_number`.
    pair_weight : Tensor | None, optional
        Optional per-pair weighting factor multiplied onto the output of
        the counting function.
    kwargs : dict[str, Any]
        Additional keyword arguments forwarded to the counting function.

    Returns
    -------
    Tensor
        Coordination numbers for all atoms (shape: ``(..., nat)``).

    Raises
    ------
    ValueError
        If the shape of the covalent radii or of the positions (without
        the last dimension) differs from the shape of ``numbers``.
    """
    dd: DD = {"device": positions.device, "dtype": positions.dtype}

    cov_radii = radii.COV_D3(**dd)[numbers] if rcov is None else rcov.to(**dd)

    if cov_radii.shape != numbers.shape:
        raise ValueError(
            f"Shape of covalent radii {cov_radii.shape} is not consistent with "
            f"({numbers.shape})."
        )
    if positions.shape[:-1] != numbers.shape:
        raise ValueError(
            f"Shape of positions ({positions.shape[:-1]}) is not consistent "
            f"with atomic numbers ({numbers.shape})."
        )

    # Pairs that enter the sum (diagonal excluded via ``mask_diagonal``).
    pair_mask = real_pairs(numbers, mask_diagonal=True)

    # Masked-out entries get machine epsilon instead of their distance, so
    # the counting function never sees the zero self-distance.
    tiny = torch.tensor(torch.finfo(positions.dtype).eps, **dd)
    dist = torch.where(
        pair_mask, storch.cdist(positions, positions, p=2), tiny
    )

    # Pairwise sum of covalent radii: R_ij = R_i + R_j.
    radius_sums = cov_radii.unsqueeze(-2) + cov_radii.unsqueeze(-1)

    pair_counts = counting_function(dist, radius_sums, **kwargs)
    if pair_weight is not None:
        pair_counts = pair_weight.to(**dd) * pair_counts

    # Start from the pair mask and narrow it if a cutoff was requested.
    included = pair_mask
    if cutoff is not None:
        if isinstance(cutoff, torch.Tensor):
            rcut = cutoff.to(**dd)
        else:
            rcut = torch.tensor(cutoff, **dd)
        included = pair_mask & (dist <= rcut)

    # Zero out excluded pairs and sum over the partner index.
    cn = torch.where(included, pair_counts, torch.tensor(0.0, **dd)).sum(dim=-1)

    return cn if cn_max is None else cut_coordination_number(cn, cn_max)


def cut_coordination_number(
    cn: Tensor, cn_max: Tensor | float | int = defaults.CUTOFF_EEQ_MAX
) -> Tensor:
    """
    Apply the smooth logarithmic cutoff used throughout mctc projects.

    The cut coordination number is evaluated as

    .. math::

        \\ln(1 + e^{cn_{max}}) - \\ln(1 + e^{cn_{max} - cn})

    Parameters
    ----------
    cn : Tensor
        Coordination numbers.
    cn_max : Tensor | float | int, optional
        Maximum coordination number. It is converted to the device and
        dtype of ``cn``. If all values are larger than 50, the cutoff is
        disabled and ``cn`` is returned unchanged.

    Returns
    -------
    Tensor
        Cut coordination numbers.
    """
    if isinstance(cn_max, torch.Tensor):
        cn_max_tensor = cn_max.to(device=cn.device, dtype=cn.dtype)
    else:
        cn_max_tensor = torch.tensor(cn_max, device=cn.device, dtype=cn.dtype)

    # Large bounds disable the cutoff: return the input untouched.
    if torch.all(cn_max_tensor > 50):
        return cn

    # ln(1 + exp(x)) is evaluated as logaddexp(0, x). The naive
    # ``log1p(exp(x))`` overflows to ``inf`` for large ``x`` (for example a
    # tensor-valued ``cn_max`` with only some entries above the threshold,
    # or reduced-precision dtypes), which turns the difference below into
    # ``inf - inf = NaN``.
    upper = torch.logaddexp(torch.zeros_like(cn_max_tensor), cn_max_tensor)
    shifted = cn_max_tensor - cn
    return upper - torch.logaddexp(torch.zeros_like(shifted), shifted)