deepmd.tf.model.frozen

Module Contents

Classes

FrozenModel

Load a model from a frozen model file; the loaded model cannot be trained.

class deepmd.tf.model.frozen.FrozenModel(model_file: str, **kwargs)

Bases: deepmd.tf.model.model.Model

Load a model from a frozen model file; the loaded model cannot be trained.

Parameters:
model_file : str

The path to the frozen model file

build(coord_: deepmd.tf.env.tf.Tensor, atype_: deepmd.tf.env.tf.Tensor, natoms: deepmd.tf.env.tf.Tensor, box: deepmd.tf.env.tf.Tensor, mesh: deepmd.tf.env.tf.Tensor, input_dict: dict, frz_model: str | None = None, ckpt_meta: str | None = None, suffix: str = '', reuse: bool | enum.Enum | None = None) -> dict

Build the model.

Parameters:
coord_ : tf.Tensor

The coordinates of atoms

atype_ : tf.Tensor

The types of the atoms

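natoms : tf.Tensor

The number of atoms, as a vector of length ntypes + 2: natoms[0] is the number of local atoms, natoms[1] is the total number of atoms held by this processor, and natoms[2 + i] is the number of atoms of type i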

box : tf.Tensor

The box vectors

mesh : tf.Tensor

The mesh vectors

input_dict : dict

The input dict

frz_model : str, optional

The path to the frozen model

ckpt_meta : str, optional

The path prefix of the checkpoint and meta files

suffix : str, optional

The suffix of the scope

reuse : bool or tf.AUTO_REUSE, optional

Whether to reuse the variables

Returns:
dict

The output dict
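The arrays fed into build have to be prepared per frame. The sketch below does this with NumPy for a single water molecule, assuming the shapes conventionally used in DeePMD-kit: coordinates flattened to [nframes, natoms * 3], atom types as [nframes, natoms], the box as [nframes, 9], and natoms laid out as [nloc, nall, n_type_0, n_type_1, ...]. The helper make_natoms and the O/H type assignment are hypothetical, the frame has no ghost atoms (so nloc equals nall), and mesh and input_dict are left out.

```python
import numpy as np


def make_natoms(atype, ntypes: int) -> np.ndarray:
    """Hypothetical helper: build [nloc, nall, n_type_0, n_type_1, ...].

    Assumes no ghost atoms, so nloc == nall == number of atoms.
    """
    atype = np.asarray(atype, dtype=np.int32)
    # Count atoms of each type; minlength keeps absent types as zeros
    counts = np.bincount(atype, minlength=ntypes)
    nloc = atype.size
    return np.concatenate(([nloc, nloc], counts)).astype(np.int32)


# One water molecule; types assumed as 0 = O, 1 = H
atype = np.array([0, 1, 1], dtype=np.int32)
coord = np.array([[0.00, 0.00, 0.00],
                  [0.96, 0.00, 0.00],
                  [-0.24, 0.93, 0.00]])
box = np.eye(3) * 10.0  # cubic box with 10 Angstrom edges

nframes = 1
coord_ = coord.reshape(nframes, -1)    # shape (1, 9)
atype_ = atype.reshape(nframes, -1)    # shape (1, 3)
box_ = box.reshape(nframes, 9)         # shape (1, 9)
natoms = make_natoms(atype, ntypes=2)  # [3, 3, 1, 2]
```

Turning these arrays into tf.Tensor inputs is left to the caller; the sketch only fixes their shapes and the natoms bookkeeping.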

get_fitting() -> deepmd.tf.fit.fitting.Fitting | dict

Get the fitting(s).

get_loss(loss: dict, lr) -> deepmd.tf.loss.loss.Loss | dict | None

Get the loss function(s).

get_rcut()

Get the cutoff radius of the model.

get_ntypes() -> int

Get the number of types.

data_stat(data)

Data statistics.

init_variables(graph: deepmd.tf.env.tf.Graph, graph_def: deepmd.tf.env.tf.GraphDef, model_type: str = 'original_model', suffix: str = '') -> None

Initialize the embedding net variables from the given frozen model.

Parameters:
graph : tf.Graph

The input frozen model graph

graph_def : tf.GraphDef

The input frozen model graph_def

model_type : str, optional

The type of the model

suffix : str, optional

The suffix of the name scope

enable_compression(suffix: str = '') -> None

Enable compression.

Parameters:
suffix : str, optional

The suffix of the name scope

get_type_map() -> list

Get the type map.
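Because the type map lists element names in type-index order, it can translate element symbols into the integer atom types that build expects. A minimal sketch, assuming the type map is a plain list of element symbols; symbols_to_atype is a hypothetical helper, and the type map is hard-coded here instead of being read from a loaded model with get_type_map().

```python
def symbols_to_atype(symbols: list[str], type_map: list[str]) -> list[int]:
    """Hypothetical helper: map element symbols to type indices.

    Raises ValueError if any symbol is missing from the type map.
    """
    # Position in the type map is the type index
    index = {name: i for i, name in enumerate(type_map)}
    missing = sorted(set(symbols) - set(index))
    if missing:
        raise ValueError(f"elements not in type map: {missing}")
    return [index[s] for s in symbols]


# In real use this list would come from model.get_type_map()
type_map = ["O", "H"]
atype = symbols_to_atype(["O", "H", "H"], type_map)  # [0, 1, 1]
```

Raising on unknown elements is a deliberate choice: silently dropping an atom would change the atom count and misalign every per-atom array passed to the model.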

classmethod update_sel(global_jdata: dict, local_jdata: dict)

Update the selection and perform neighbor statistics.

Parameters:
global_jdata : dict

The global data, containing the training section

local_jdata : dict

The local data, referring to the current class
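In type-wise descriptors such as se_e2_a, the selection (sel) is a per-type upper bound on the number of neighbors kept within the cutoff radius reported by get_rcut, and neighbor statistics estimate those bounds from data. The brute-force sketch below shows the idea for one frame: for each type, it returns the largest number of neighbors of that type found around any atom. It ignores periodic boundary conditions and is an illustration only, not DeePMD-kit's implementation; max_neighbors is a hypothetical name.

```python
import numpy as np


def max_neighbors(coord: np.ndarray, atype: np.ndarray,
                  ntypes: int, rcut: float) -> list[int]:
    """Hypothetical helper: per-type maximum neighbor count within rcut.

    coord: [natoms, 3] Cartesian coordinates, no periodic boundary conditions.
    atype: [natoms] integer type indices.
    """
    # Pairwise distance matrix, shape [natoms, natoms]
    diff = coord[:, None, :] - coord[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    # An atom is not its own neighbor
    np.fill_diagonal(dist, np.inf)
    within = dist < rcut
    result = []
    for t in range(ntypes):
        # For every center atom, count neighbors of type t inside the cutoff
        counts = within[:, atype == t].sum(axis=1)
        result.append(int(counts.max()) if counts.size else 0)
    return result


# Toy O-H-H frame (types 0 = O, 1 = H assumed), H-H distance about 1.41
coord = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
atype = np.array([0, 1, 1])
sel_estimate = max_neighbors(coord, atype, ntypes=2, rcut=1.2)  # [1, 2]
```

The O(N^2) distance matrix keeps the sketch short; for large systems a cell list or neighbor list would be the usual choice, and a safety margin is typically added on top of the observed maxima.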

serialize(suffix: str = '') -> dict

Serialize the model.

There is no suffix in a native DP model, but it is important for the TF backend.

Parameters:
suffix : str, optional

Name suffix to identify this model

Returns:
dict

The serialized data

classmethod deserialize(data: dict, suffix: str = '')

Deserialize the model.

There is no suffix in a native DP model, but it is important for the TF backend.

Parameters:
data : dict

The serialized data

suffix : str, optional

Name suffix to identify this model

Returns:
Model

The deserialized Model
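serialize and deserialize are meant to be inverse operations: a classmethod rebuilds a model from the dict produced by serialize, while the suffix only matters for naming TF scopes. The toy class below sketches that contract in plain Python. ToyFrozenModel and its dictionary keys are hypothetical and do not reflect the actual serialized format of FrozenModel.

```python
from __future__ import annotations


class ToyFrozenModel:
    """Hypothetical stand-in illustrating the serialize/deserialize contract."""

    def __init__(self, model_file: str, suffix: str = "") -> None:
        self.model_file = model_file
        # The suffix only names the TF scope; it is not part of the data
        self.suffix = suffix

    def serialize(self, suffix: str = "") -> dict:
        # suffix is accepted to mirror the documented signature; unused here.
        # The returned dict is plain data, so the suffix is not stored.
        return {"type": "toy_frozen", "model_file": self.model_file}

    @classmethod
    def deserialize(cls, data: dict, suffix: str = "") -> ToyFrozenModel:
        if data.get("type") != "toy_frozen":
            raise ValueError(f"unexpected model type: {data.get('type')!r}")
        # A new suffix can be chosen freely when rebuilding
        return cls(data["model_file"], suffix=suffix)


original = ToyFrozenModel("frozen_model.pb", suffix="_a")
restored = ToyFrozenModel.deserialize(original.serialize(), suffix="_b")
```

Keeping the suffix out of the serialized dict lets the same data be loaded into different scopes, which matches the note that a native DP model has no suffix.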