# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
List,
Optional,
)
from deepmd.utils.spin import (
Spin,
)
from .descriptor import (
Descriptor,
)
from .se_a import (
DescrptSeA,
)
# Module-level logger, keyed on this module's dotted import name.
log: logging.Logger = logging.getLogger(name=__name__)
@Descriptor.register("se_a_tpe_v2")
@Descriptor.register("se_a_ebd_v2")
class DescrptSeAEbdV2(DescrptSeA):
    r"""A compressible se_a_ebd model.

    This model is a wrapper for :class:`DescrptSeA` that always sets
    ``stripped_type_embedding=True``. All other arguments are forwarded
    unchanged to :class:`DescrptSeA`; see that class for their meaning.

    The class is registered under both the ``"se_a_ebd_v2"`` and the
    ``"se_a_tpe_v2"`` descriptor names.
    """

    def __init__(
        self,
        rcut: float,
        rcut_smth: float,
        sel: List[int],
        neuron: Optional[List[int]] = None,
        axis_neuron: int = 8,
        resnet_dt: bool = False,
        trainable: bool = True,
        seed: Optional[int] = None,
        type_one_side: bool = True,
        exclude_types: Optional[List[List[int]]] = None,
        set_davg_zero: bool = False,
        activation_function: str = "tanh",
        precision: str = "default",
        uniform_seed: bool = False,
        multi_task: bool = False,
        spin: Optional[Spin] = None,
        **kwargs,
    ) -> None:
        """Initialize the descriptor by delegating to :class:`DescrptSeA`.

        Parameters
        ----------
        rcut, rcut_smth, sel, axis_neuron, resnet_dt, trainable, seed,
        type_one_side, set_davg_zero, activation_function, precision,
        uniform_seed, multi_task, spin
            Forwarded unchanged to :class:`DescrptSeA`.
        neuron
            Forwarded to :class:`DescrptSeA`. ``None`` (the default) means
            ``[24, 48, 96]``.
        exclude_types
            Forwarded to :class:`DescrptSeA`. ``None`` (the default) means
            an empty list.
        **kwargs
            Extra keyword arguments forwarded to :class:`DescrptSeA`. Do not
            pass ``stripped_type_embedding`` here: it is always supplied as
            ``True``, so a duplicate raises ``TypeError``.
        """
        # Build fresh lists per call instead of sharing one mutable default
        # object across every instance.
        if neuron is None:
            neuron = [24, 48, 96]
        if exclude_types is None:
            exclude_types = []
        DescrptSeA.__init__(
            self,
            rcut,
            rcut_smth,
            sel,
            neuron=neuron,
            axis_neuron=axis_neuron,
            resnet_dt=resnet_dt,
            trainable=trainable,
            seed=seed,
            type_one_side=type_one_side,
            exclude_types=exclude_types,
            set_davg_zero=set_davg_zero,
            activation_function=activation_function,
            precision=precision,
            uniform_seed=uniform_seed,
            multi_task=multi_task,
            spin=spin,
            # The only difference from plain DescrptSeA.
            stripped_type_embedding=True,
            **kwargs,
        )