"""Convert frozen DeePMD-kit model graphs (``deepmd.utils.convert``) from older versions to the 2.1 format."""

import os
import re
import tempfile

from deepmd.env import tf
from google.protobuf import text_format
from tensorflow.python.platform import gfile


def convert_13_to_21(input_model: str, output_model: str):
    """Convert DP 1.3 graph to 2.1 graph.

    The intermediate graph text is written inside a private temporary
    directory, so an existing ``frozen_model.pbtxt`` in the working
    directory is never overwritten or deleted, and the intermediate file
    is removed even when a conversion step raises.

    Parameters
    ----------
    input_model : str
        filename of the input graph
    output_model : str
        filename of the output graph
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # tmpdir is absolute, so joining it onto the './' logdir used by the
        # write_graph-based helper leaves the path unchanged.
        pbtxt = os.path.join(tmpdir, 'frozen_model.pbtxt')
        convert_pb_to_pbtxt(input_model, pbtxt)
        convert_dp13_to_dp20(pbtxt)
        convert_dp20_to_dp21(pbtxt)
        convert_pbtxt_to_pb(pbtxt, output_model)
    print("the converted output model (2.1 support) is saved in %s" % output_model)


# NOTE(review): this module defines convert_13_to_21 twice; being later in
# the file, this definition is the one that takes effect. One copy can go.
def convert_13_to_21(input_model: str, output_model: str):
    """Convert DP 1.3 graph to 2.1 graph.

    The intermediate graph text is written inside a private temporary
    directory, so an existing ``frozen_model.pbtxt`` in the working
    directory is never overwritten or deleted, and the intermediate file
    is removed even when a conversion step raises.

    Parameters
    ----------
    input_model : str
        filename of the input graph
    output_model : str
        filename of the output graph
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # tmpdir is absolute, so joining it onto the './' logdir used by the
        # write_graph-based helper leaves the path unchanged.
        pbtxt = os.path.join(tmpdir, 'frozen_model.pbtxt')
        convert_pb_to_pbtxt(input_model, pbtxt)
        convert_dp13_to_dp20(pbtxt)
        convert_dp20_to_dp21(pbtxt)
        convert_pbtxt_to_pb(pbtxt, output_model)
    print("the converted output model (2.1 support) is saved in %s" % output_model)
def convert_12_to_21(input_model: str, output_model: str):
    """Convert DP 1.2 graph to 2.1 graph.

    The intermediate graph text is written inside a private temporary
    directory, so an existing ``frozen_model.pbtxt`` in the working
    directory is never overwritten or deleted, and the intermediate file
    is removed even when a conversion step raises.

    Parameters
    ----------
    input_model : str
        filename of the input graph
    output_model : str
        filename of the output graph
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # tmpdir is absolute, so joining it onto the './' logdir used by the
        # write_graph-based helper leaves the path unchanged.
        pbtxt = os.path.join(tmpdir, 'frozen_model.pbtxt')
        convert_pb_to_pbtxt(input_model, pbtxt)
        convert_dp12_to_dp13(pbtxt)
        convert_dp13_to_dp20(pbtxt)
        convert_dp20_to_dp21(pbtxt)
        convert_pbtxt_to_pb(pbtxt, output_model)
    print("the converted output model (2.1 support) is saved in %s" % output_model)
def convert_10_to_21(input_model: str, output_model: str):
    """Convert DP 1.0 graph to 2.1 graph.

    The intermediate graph text is written inside a private temporary
    directory, so an existing ``frozen_model.pbtxt`` in the working
    directory is never overwritten or deleted, and the intermediate file
    is removed even when a conversion step raises.

    Parameters
    ----------
    input_model : str
        filename of the input graph
    output_model : str
        filename of the output graph
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # tmpdir is absolute, so joining it onto the './' logdir used by the
        # write_graph-based helper leaves the path unchanged.
        pbtxt = os.path.join(tmpdir, 'frozen_model.pbtxt')
        convert_pb_to_pbtxt(input_model, pbtxt)
        convert_dp10_to_dp11(pbtxt)
        convert_dp12_to_dp13(pbtxt)
        convert_dp13_to_dp20(pbtxt)
        convert_dp20_to_dp21(pbtxt)
        convert_pbtxt_to_pb(pbtxt, output_model)
    print("the converted output model (2.1 support) is saved in %s" % output_model)
def convert_012_to_21(input_model: str, output_model: str):
    """Convert DP 0.12 graph to 2.1 graph.

    The intermediate graph text is written inside a private temporary
    directory, so an existing ``frozen_model.pbtxt`` in the working
    directory is never overwritten or deleted, and the intermediate file
    is removed even when a conversion step raises.

    Parameters
    ----------
    input_model : str
        filename of the input graph
    output_model : str
        filename of the output graph
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # tmpdir is absolute, so joining it onto the './' logdir used by the
        # write_graph-based helper leaves the path unchanged.
        pbtxt = os.path.join(tmpdir, 'frozen_model.pbtxt')
        convert_pb_to_pbtxt(input_model, pbtxt)
        convert_dp012_to_dp10(pbtxt)
        convert_dp10_to_dp11(pbtxt)
        convert_dp12_to_dp13(pbtxt)
        convert_dp13_to_dp20(pbtxt)
        convert_dp20_to_dp21(pbtxt)
        convert_pbtxt_to_pb(pbtxt, output_model)
    print("the converted output model (2.1 support) is saved in %s" % output_model)
def convert_20_to_21(input_model: str, output_model: str):
    """Convert DP 2.0 graph to 2.1 graph.

    The intermediate graph text is written inside a private temporary
    directory, so an existing ``frozen_model.pbtxt`` in the working
    directory is never overwritten or deleted, and the intermediate file
    is removed even when a conversion step raises.

    Parameters
    ----------
    input_model : str
        filename of the input graph
    output_model : str
        filename of the output graph
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # tmpdir is absolute, so joining it onto the './' logdir used by the
        # write_graph-based helper leaves the path unchanged.
        pbtxt = os.path.join(tmpdir, 'frozen_model.pbtxt')
        convert_pb_to_pbtxt(input_model, pbtxt)
        convert_dp20_to_dp21(pbtxt)
        convert_pbtxt_to_pb(pbtxt, output_model)
    print("the converted output model (2.1 support) is saved in %s" % output_model)
def convert_pb_to_pbtxt(pbfile: str, pbtxtfile: str):
    """Convert DP graph to graph text.

    Parameters
    ----------
    pbfile : str
        filename of the input graph
    pbtxtfile : str
        filename of the output graph text; it is passed to
        ``tf.train.write_graph`` together with the logdir ``'./'``
    """
    # GFile replaces the deprecated FastGFile (same reading behaviour).
    with tf.gfile.GFile(pbfile, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import into a throw-away graph: the GraphDef is still checked for
    # importability, but its nodes are not added to the caller's default
    # graph on every conversion.
    with tf.Graph().as_default():
        tf.import_graph_def(graph_def, name='')
    tf.train.write_graph(graph_def, './', pbtxtfile, as_text=True)
def convert_pbtxt_to_pb(pbtxtfile: str, pbfile: str):
    """Convert DP graph text to graph.

    Parameters
    ----------
    pbtxtfile : str
        filename of the input graph text
    pbfile : str
        filename of the output graph; it is passed to
        ``tf.train.write_graph`` together with the logdir ``'./'``
    """
    # GFile replaces the deprecated FastGFile (same reading behaviour).
    with tf.gfile.GFile(pbtxtfile, 'r') as f:
        graph_def = tf.GraphDef()
        file_content = f.read()
    # Merges the human-readable string in `file_content` into `graph_def`.
    text_format.Merge(file_content, graph_def)
    tf.train.write_graph(graph_def, './', pbfile, as_text=False)
def convert_dp012_to_dp10(file: str):
    """Convert DP 0.12 graph text to 1.0 graph text.

    Renames the ``*Norot`` ops to their ``*SeA`` counterparts:
    ``DescrptNorot`` -> ``DescrptSeA``, ``ProdForceNorot`` ->
    ``ProdForceSeA`` and ``ProdVirialNorot`` -> ``ProdVirialSeA``.
    The file is rewritten in place.

    Parameters
    ----------
    file : str
        filename of the graph text
    """
    with open(file, encoding="utf-8") as fp:
        file_content = fp.read()
    file_content = (
        file_content
        .replace('DescrptNorot', 'DescrptSeA')
        .replace('ProdForceNorot', 'ProdForceSeA')
        .replace('ProdVirialNorot', 'ProdVirialSeA')
    )
    with open(file, 'w', encoding="utf-8") as fp:
        fp.write(file_content)
def convert_dp10_to_dp11(file: str):
    """Convert DP 1.0 graph text to 1.1 graph text.

    Appends a ``fitting_attr/daparam`` Const node (int32 value 0) to the
    graph text. If a node with that name is already present the file is
    left untouched, so running the conversion twice does not add a second
    node with the same name.

    Parameters
    ----------
    file : str
        filename of the graph text
    """
    with open(file, encoding="utf-8") as f:
        if re.search(r'name:\s*"fitting_attr/daparam"', f.read()):
            return
    with open(file, 'a', encoding="utf-8") as f:
        f.write("""
node {
  name: "fitting_attr/daparam"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
""")
def convert_dp12_to_dp13(file: str):
    """Convert DP 1.2 graph text to 1.3 graph text.

    For every line containing ``name`` together with ``DescrptSeA``,
    ``ProdForceSeA`` or ``ProdVirialSeA``, an ``attr`` block with key
    ``"T"`` and type ``DT_DOUBLE`` is inserted immediately before the next
    line that contains both ``attr`` and ``{``. The file is rewritten in
    place.

    Parameters
    ----------
    file : str
        filename of the graph text

    Raises
    ------
    ValueError
        if a matching line is not followed by any ``attr {`` line; the file
        is left unchanged in that case
    """
    t_attr_block = (
        '  attr {\n'
        '    key: "T"\n'
        '    value {\n'
        '      type: DT_DOUBLE\n'
        '    }\n'
        '  }\n'
    )
    with open(file, "r", encoding="utf-8") as f:
        lines = f.readlines()

    # Collect pieces in a list and join once instead of repeated str +=.
    out = []
    n_lines = len(lines)
    ii = 0
    while ii < n_lines:
        line = lines[ii]
        out.append(line)
        ii += 1
        if 'name' in line and ('DescrptSeA' in line or 'ProdForceSeA' in line or 'ProdVirialSeA' in line):
            # Copy the remaining lines up to the first attr block unchanged.
            while ii < n_lines and not ('attr' in lines[ii] and '{' in lines[ii]):
                out.append(lines[ii])
                ii += 1
            if ii == n_lines:
                raise ValueError(
                    "no 'attr {' line found after %r in %s" % (line.strip(), file)
                )
            out.append(t_attr_block)

    with open(file, "w", encoding="utf-8") as f:
        f.write(''.join(out))
def convert_dp13_to_dp20(fname: str):
    """Convert DP 1.3 graph text to 2.0 graph text.

    Appends a ``model_attr/model_version`` Const node holding the string
    ``"1.0"``, unless a node with that name already exists, so the graph
    never gets a second node with the same name. Then renames ``DescrptSeA``
    to ``ProdEnvMatA`` and ``DescrptSeR`` to ``ProdEnvMatR``. The file is
    rewritten in place.

    Parameters
    ----------
    fname : str
        filename of the graph text
    """
    with open(fname, encoding="utf-8") as fp:
        file_content = fp.read()
    if not re.search(r'name:\s*"model_attr/model_version"', file_content):
        file_content += """
node {
  name: "model_attr/model_version"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_STRING
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_STRING
        tensor_shape {
        }
        string_val: "1.0"
      }
    }
  }
}
"""
    file_content = (
        file_content
        .replace('DescrptSeA', 'ProdEnvMatA')
        .replace('DescrptSeR', 'ProdEnvMatR')
    )
    with open(fname, 'w', encoding="utf-8") as fp:
        fp.write(file_content)
def convert_dp20_to_dp21(fname: str):
    """Convert DP 2.0 graph text to 2.1 graph text.

    Changes the string value of the ``model_attr/model_version`` node from
    ``"1.0"`` to ``"1.1"``. Renames ``TabulateFusion``,
    ``TabulateFusionGrad`` and ``TabulateFusionGradGrad`` to
    ``TabulateFusionSeA``, ``TabulateFusionSeAGrad`` and
    ``TabulateFusionSeAGradGrad``. The file is rewritten in place, and
    running the conversion again leaves it unchanged.

    Parameters
    ----------
    fname : str
        filename of the graph text
    """
    with open(fname, encoding="utf-8") as fp:
        file_content = fp.read()
    # Target the first ``string_val`` after the model_version node's name
    # (the tempered ``(?!string_val:)`` token cannot skip past it). This
    # avoids matching the whole node verbatim, which would depend on the
    # exact indentation and line layout of the graph text.
    file_content = re.sub(
        r'(name:\s*"model_attr/model_version"(?:(?!string_val:).)*string_val:\s*)"1\.0"',
        r'\1"1.1"',
        file_content,
        flags=re.DOTALL,
    )
    # The three op names share the ``TabulateFusion`` prefix, so one pass
    # renames them all. The lookahead skips names that are already
    # ``TabulateFusionSe*``, which prevents ``TabulateFusionSeASeA``.
    file_content = re.sub(r'TabulateFusion(?!Se)', 'TabulateFusionSeA', file_content)
    with open(fname, 'w', encoding="utf-8") as fp:
        fp.write(file_content)