# Source code for deepmd.entrypoints.freeze

#!/usr/bin/env python3
"""Script for freezing TF trained graph so it can be used with LAMMPS and i-PI.

References
----------
https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
"""

from deepmd.env import tf
from deepmd.env import op_module
from deepmd.utils.sess import run_sess
from os.path import abspath

# load grad of force module
import deepmd.op

from typing import List, Optional

__all__ = ["freeze"]


def _make_node_names(model_type: str, modifier_type: Optional[str] = None) -> List[str]:
    """Get node names based on model type.

    Parameters
    ----------
    model_type : str
        str type of model
    modifier_type : Optional[str], optional
        modifier type if any, by default None

    Returns
    -------
    List[str]
        list with all node names to freeze

    Raises
    ------
    RuntimeError
        if unknown model type
    """
    nodes = [
        "model_type",
        "descrpt_attr/rcut",
        "descrpt_attr/ntypes",
        "model_attr/tmap",
        "model_attr/model_type",
        "model_attr/model_version",
        "train_attr/min_nbor_dist",
        "train_attr/training_script",
    ]

    if model_type == "ener":
        nodes += [
            "o_energy",
            "o_force",
            "o_virial",
            "o_atom_energy",
            "o_atom_virial",
            "fitting_attr/dfparam",
            "fitting_attr/daparam",
        ]
    elif model_type == "wfc":
        nodes += [
            "o_wfc",
            "model_attr/sel_type",
            "model_attr/output_dim",
        ]
    elif model_type == "dipole":
        nodes += [
            "o_dipole",
            "o_global_dipole",
            "o_force",
            "o_virial",
            "o_atom_virial",
            "o_rmat",
            "o_rmat_deriv",
            "o_nlist",
            "o_rij",
            "descrpt_attr/sel",
            "descrpt_attr/ndescrpt",
            "model_attr/sel_type",
            "model_attr/output_dim",
        ]
    elif model_type == "polar":
        nodes += [
            "o_polar",
            "o_global_polar",
            "o_force",
            "o_virial",
            "o_atom_virial",
            "model_attr/sel_type",
            "model_attr/output_dim",
        ]
    elif model_type == "global_polar":
        nodes += [
            "o_global_polar",
            "model_attr/sel_type",
            "model_attr/output_dim",
        ]
    else:
        raise RuntimeError(f"unknow model type {model_type}")
    if modifier_type == "dipole_charge":
        nodes += [
            "modifier_attr/type",
            "modifier_attr/mdl_name",
            "modifier_attr/mdl_charge_map",
            "modifier_attr/sys_charge_map",
            "modifier_attr/ewald_h",
            "modifier_attr/ewald_beta",
            "dipole_charge/descrpt_attr/rcut",
            "dipole_charge/descrpt_attr/ntypes",
            "dipole_charge/model_attr/tmap",
            "dipole_charge/model_attr/model_type",
            "o_dm_force",
            "dipole_charge/model_attr/sel_type",
            "dipole_charge/o_dipole",
            "dipole_charge/model_attr/output_dim",
            "o_dm_virial",
            "o_dm_av",
        ]
    return nodes


def freeze(
    *,
    checkpoint_folder: str,
    output: str,
    node_names: Optional[str] = None,
    **kwargs
) -> None:
    """Freeze the graph in supplied folder.

    Restores the latest checkpoint found in `checkpoint_folder`, converts the
    variables reachable from the selected output nodes into constants and
    writes the resulting graph definition to `output`.

    Parameters
    ----------
    checkpoint_folder : str
        location of the folder with model
    output : str
        output file name
    node_names : Optional[str], optional
        comma-separated names of nodes to output; when None the list is
        derived from the model (and modifier) type stored in the graph,
        by default None
    **kwargs
        additional arguments; accepted and ignored

    Raises
    ------
    RuntimeError
        if no checkpoint state can be found in `checkpoint_folder`
    """
    # We retrieve our checkpoint fullpath.
    # get_checkpoint_state returns None when the folder holds no checkpoint
    # state; fail with an explicit message instead of an AttributeError.
    checkpoint = tf.train.get_checkpoint_state(checkpoint_folder)
    if checkpoint is None:
        raise RuntimeError(f"no checkpoint found in folder {checkpoint_folder}")
    input_checkpoint = checkpoint.model_checkpoint_path

    # expand the output file to full path
    output_graph = abspath(output)

    # Before exporting the graph we need to specify its output nodes: this is
    # how TF decides which part of the graph it has to keep and which part it
    # can discard.
    # NOTE: `node_names` is plural, because there can be multiple output
    # nodes, e.g. node_names = "energy_test,force_test,virial_test,t_rcut"

    # We import the meta graph and retrieve a Saver. Devices are cleared to
    # allow TensorFlow to control on which device it will load operations.
    saver = tf.train.import_meta_graph(
        f"{input_checkpoint}.meta", clear_devices=True
    )

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    nodes = [n.name for n in input_graph_def.node]

    # We start a session and restore the graph weights
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        model_type = run_sess(sess, "model_attr/model_type:0", feed_dict={}).decode(
            "utf-8"
        )
        if "modifier_attr/type" in nodes:
            modifier_type = run_sess(
                sess, "modifier_attr/type:0", feed_dict={}
            ).decode("utf-8")
        else:
            modifier_type = None

        if node_names is None:
            output_node_list = _make_node_names(model_type, modifier_type)
        else:
            # tolerate whitespace around the commas and skip empty entries
            output_node_list = [
                name.strip() for name in node_names.split(",") if name.strip()
            ]
        print(f"The following nodes will be frozen: {output_node_list}")

        # We use a built-in TF helper to export variables to constants
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            input_graph_def,  # The graph_def is used to retrieve the nodes
            output_node_list,  # The output node names select the useful nodes
        )

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print(f"{len(output_graph_def.node):d} ops in the final graph.")