# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
Optional,
Tuple,
)
import numpy as np
from deepmd.env import (
GLOBAL_TF_FLOAT_PRECISION,
tf,
)
from deepmd.utils.spin import (
Spin,
)
# from deepmd.descriptor import DescrptLocFrame
# from deepmd.descriptor import DescrptSeA
# from deepmd.descriptor import DescrptSeT
# from deepmd.descriptor import DescrptSeAEbd
# from deepmd.descriptor import DescrptSeAEf
# from deepmd.descriptor import DescrptSeR
from .descriptor import (
Descriptor,
)
@Descriptor.register("hybrid")
class DescrptHybrid(Descriptor):
    """Concatenate a list of descriptors to form a new descriptor.

    Parameters
    ----------
    list : list
        The sub-descriptors to concatenate. Each entry is either an already
        constructed ``Descriptor`` instance, which is used as is, or a
        ``dict`` of keyword arguments from which a ``Descriptor`` is built.
    multi_task : bool
        Stored on the instance and forwarded to every sub-descriptor that is
        built from a ``dict``.
    ntypes : int, optional
        Forwarded to every sub-descriptor that is built from a ``dict``.
    spin : Spin, optional
        Forwarded to every sub-descriptor that is built from a ``dict``.
    **kwargs
        Accepted for interface compatibility and ignored.

    Raises
    ------
    RuntimeError
        If ``list`` is ``None`` or contains no descriptor.
    NotImplementedError
        If an entry of ``list`` is neither a ``Descriptor`` nor a ``dict``.
    AssertionError
        If the sub-descriptors do not all report the same number of atom
        types.
    """

    def __init__(
        self,
        list: list,
        multi_task: bool = False,
        ntypes: Optional[int] = None,
        spin: Optional[Spin] = None,
        **kwargs,
    ) -> None:
        """Constructor."""
        # NOTE: the parameter name ``list`` shadows the built-in; it is part of
        # the public interface, so it is only aliased here, not renamed.
        descrpt_list = list
        empty_msg = "cannot build descriptor from an empty list of descriptors."
        if descrpt_list is None:
            raise RuntimeError(empty_msg)

        self.multi_task = multi_task
        formatted_descrpt_list = []
        for item in descrpt_list:
            if isinstance(item, Descriptor):
                formatted_descrpt_list.append(item)
            elif isinstance(item, dict):
                formatted_descrpt_list.append(
                    Descriptor(
                        **item, ntypes=ntypes, spin=spin, multi_task=multi_task
                    )
                )
            else:
                raise NotImplementedError(
                    "unsupported descriptor specification of type "
                    f"{type(item).__name__}; expected a Descriptor or a dict"
                )
        # Checked after normalisation so that every empty iterable (not only a
        # literal ``[]``) is rejected here instead of failing later on.
        if not formatted_descrpt_list:
            raise RuntimeError(empty_msg)

        self.descrpt_list = formatted_descrpt_list
        self.numb_descrpt = len(self.descrpt_list)
        for idx, descrpt in enumerate(self.descrpt_list[1:], start=1):
            if descrpt.get_ntypes() != self.descrpt_list[0].get_ntypes():
                # Raised explicitly (instead of using ``assert``) so that the
                # check is not stripped when Python runs with ``-O``.
                raise AssertionError(
                    f"number of atom types in {idx}th descrptor does not match others"
                )

    def get_rcut(self) -> float:
        """Returns the cut-off radius, i.e. the largest one among the sub-descriptors."""
        return np.max([descrpt.get_rcut() for descrpt in self.descrpt_list])

    def get_ntypes(self) -> int:
        """Returns the number of atom types, as reported by the first sub-descriptor."""
        return self.descrpt_list[0].get_ntypes()

    def get_dim_out(self) -> int:
        """Returns the output dimension, i.e. the sum over all sub-descriptors."""
        return sum(descrpt.get_dim_out() for descrpt in self.descrpt_list)

    def get_nlist(
        self,
    ) -> Tuple[tf.Tensor, tf.Tensor, List[int], List[int]]:
        """Get the neighbor information of the descriptor.

        Returns the nlist of the sub-descriptor with the largest cut-off
        radius.

        Returns
        -------
        nlist
            Neighbor list
        rij
            The relative distance between the neighbor and the center atom.
        sel_a
            The number of neighbors with full information
        sel_r
            The number of neighbors with only radial information
        """
        maxr_idx = np.argmax([descrpt.get_rcut() for descrpt in self.descrpt_list])
        return self.get_nlist_i(maxr_idx)

    def get_nlist_i(
        self, ii: int
    ) -> Tuple[tf.Tensor, tf.Tensor, List[int], List[int]]:
        """Get the neighbor information of the ii-th descriptor.

        Parameters
        ----------
        ii : int
            The index of the descriptor

        Returns
        -------
        nlist
            Neighbor list
        rij
            The relative distance between the neighbor and the center atom.
        sel_a
            The number of neighbors with full information
        sel_r
            The number of neighbors with only radial information
        """
        descrpt = self.descrpt_list[ii]
        return (descrpt.nlist, descrpt.rij, descrpt.sel_a, descrpt.sel_r)

    def build(
        self,
        coord_: tf.Tensor,
        atype_: tf.Tensor,
        natoms: tf.Tensor,
        box_: tf.Tensor,
        mesh: tf.Tensor,
        input_dict: dict,
        reuse: Optional[bool] = None,
        suffix: str = "",
    ) -> tf.Tensor:
        """Build the computational graph for the descriptor.

        Every sub-descriptor is built with the scope suffix
        ``f"{suffix}_{idx}"``; the outputs are concatenated along the feature
        axis.

        Parameters
        ----------
        coord_
            The coordinate of atoms
        atype_
            The type of atoms
        natoms
            The number of atoms. This tensor has the length of Ntypes + 2
            natoms[0]: number of local atoms
            natoms[1]: total number of atoms held by this processor
            natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        box_ : tf.Tensor
            The box of the system
        mesh
            For historical reasons, only the length of the Tensor matters.
            if size of mesh == 6, pbc is assumed.
            if size of mesh == 0, no-pbc is assumed.
        input_dict
            Dictionary for additional inputs
        reuse
            The weights in the networks should be reused when get the variable.
        suffix
            Name suffix to identify this descriptor

        Returns
        -------
        descriptor
            The output descriptor
        """
        # The two constants are created only to add the named nodes "rcut" and
        # "ntypes" to the graph; their Python handles are not needed here.
        with tf.variable_scope("descrpt_attr" + suffix, reuse=reuse):
            tf.constant(self.get_rcut(), name="rcut", dtype=GLOBAL_TF_FLOAT_PRECISION)
            tf.constant(self.get_ntypes(), name="ntypes", dtype=tf.int32)

        # NOTE(review): the sub-descriptors are built outside the
        # "descrpt_attr" scope so that their own scopes are not nested in it.
        all_dout = []
        for idx, descrpt in enumerate(self.descrpt_list):
            dout = descrpt.build(
                coord_,
                atype_,
                natoms,
                box_,
                mesh,
                input_dict,
                suffix=suffix + f"_{idx}",
                reuse=reuse,
            )
            # Flatten to (rows, dim_out_i) so the outputs can be joined on axis 1.
            all_dout.append(tf.reshape(dout, [-1, descrpt.get_dim_out()]))
        dout = tf.concat(all_dout, axis=1)
        dout = tf.reshape(dout, [-1, natoms[0], self.get_dim_out()])
        return dout

    def prod_force_virial(
        self, atom_ener: tf.Tensor, natoms: tf.Tensor
    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Compute force and virial.

        The contributions of all sub-descriptors are summed.

        Parameters
        ----------
        atom_ener
            The atomic energy
        natoms
            The number of atoms. This tensor has the length of Ntypes + 2
            natoms[0]: number of local atoms
            natoms[1]: total number of atoms held by this processor
            natoms[i]: 2 <= i < Ntypes+2, number of type i atoms

        Returns
        -------
        force
            The force on atoms
        virial
            The total virial
        atom_virial
            The atomic virial
        """
        for idx, descrpt in enumerate(self.descrpt_list):
            ff, vv, av = descrpt.prod_force_virial(atom_ener, natoms)
            if idx == 0:
                force, virial, atom_virial = ff, vv, av
            else:
                # Out-of-place addition: never mutate the objects returned by
                # the first sub-descriptor.
                force = force + ff
                virial = virial + vv
                atom_virial = atom_virial + av
        return force, virial, atom_virial

    def enable_compression(
        self,
        min_nbor_dist: float,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        table_extrapolate: float = 5.0,
        table_stride_1: float = 0.01,
        table_stride_2: float = 0.1,
        check_frequency: int = -1,
        suffix: str = "",
    ) -> None:
        """Receive the statistics (distance, max_nbor_size and env_mat_range) of the
        training data.

        The call is forwarded to every sub-descriptor with the scope suffix
        ``f"{suffix}_{idx}"``.

        Parameters
        ----------
        min_nbor_dist : float
            The nearest distance between atoms
        graph : tf.Graph
            The graph of the model
        graph_def : tf.GraphDef
            The graph_def of the model
        table_extrapolate : float, default: 5.
            The scale of model extrapolation
        table_stride_1 : float, default: 0.01
            The uniform stride of the first table
        table_stride_2 : float, default: 0.1
            The uniform stride of the second table
        check_frequency : int, default: -1
            The overflow check frequency
        suffix : str, optional
            The suffix of the scope
        """
        for idx, descrpt in enumerate(self.descrpt_list):
            descrpt.enable_compression(
                min_nbor_dist,
                graph,
                graph_def,
                table_extrapolate,
                table_stride_1,
                table_stride_2,
                check_frequency,
                suffix=f"{suffix}_{idx}",
            )

    def enable_mixed_precision(self, mixed_prec: Optional[dict] = None) -> None:
        """Receive the mixed precision setting.

        The setting is forwarded unchanged to every sub-descriptor.

        Parameters
        ----------
        mixed_prec
            The mixed precision setting used in the embedding net
        """
        for descrpt in self.descrpt_list:
            descrpt.enable_mixed_precision(mixed_prec)

    def init_variables(
        self,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        suffix: str = "",
    ) -> None:
        """Init the embedding net variables with the given dict.

        The call is forwarded to every sub-descriptor with the scope suffix
        ``f"{suffix}_{idx}"``.

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        suffix : str, optional
            The suffix of the scope
        """
        for idx, descrpt in enumerate(self.descrpt_list):
            descrpt.init_variables(graph, graph_def, suffix=f"{suffix}_{idx}")

    def get_tensor_names(self, suffix: str = "") -> Tuple[str, ...]:
        """Get names of tensors.

        The names of all sub-descriptors are collected in order, each queried
        with the scope suffix ``f"{suffix}_{idx}"``.

        Parameters
        ----------
        suffix : str
            The suffix of the scope

        Returns
        -------
        Tuple[str]
            Names of tensors
        """
        tensor_names = []
        for idx, descrpt in enumerate(self.descrpt_list):
            tensor_names.extend(descrpt.get_tensor_names(suffix=f"{suffix}_{idx}"))
        return tuple(tensor_names)

    def pass_tensors_from_frz_model(
        self,
        *tensors: tf.Tensor,
    ) -> None:
        """Pass the descrpt_reshape tensor as well as descrpt_deriv tensor from the frz graph_def.

        ``tensors`` is consumed in order: each sub-descriptor receives as many
        consecutive tensors as it reports names in ``get_tensor_names()``.

        Parameters
        ----------
        *tensors : tf.Tensor
            passed tensors
        """
        start = 0
        for descrpt in self.descrpt_list:
            # Only the number of names matters here, so the default suffix is fine.
            n_tensors = len(descrpt.get_tensor_names())
            descrpt.pass_tensors_from_frz_model(*tensors[start : start + n_tensors])
            start += n_tensors

    @property
    def explicit_ntypes(self) -> bool:
        """Explicit ntypes with type embedding.

        True if any sub-descriptor reports ``explicit_ntypes``.
        """
        return any(descrpt.explicit_ntypes for descrpt in self.descrpt_list)

    @classmethod
    def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict:
        """Update the selection and perform neighbor statistics.

        Parameters
        ----------
        global_jdata : dict
            The global data, containing the training section
        local_jdata : dict
            The local data refer to the current class

        Returns
        -------
        dict
            A shallow copy of ``local_jdata`` whose ``"list"`` entry holds the
            result of ``Descriptor.update_sel`` for every sub-descriptor; the
            input dict itself is not modified.
        """
        local_jdata_cpy = local_jdata.copy()
        local_jdata_cpy["list"] = [
            Descriptor.update_sel(global_jdata, sub_jdata)
            for sub_jdata in local_jdata["list"]
        ]
        return local_jdata_cpy