# Source code for deepmd.nvnmd.descriptor.se_a

# SPDX-License-Identifier: LGPL-3.0-or-later
import logging

import numpy as np

from deepmd.env import (
    GLOBAL_NP_FLOAT_PRECISION,
    GLOBAL_TF_FLOAT_PRECISION,
    op_module,
    tf,
)

#
from deepmd.nvnmd.utils.config import (
    nvnmd_cfg,
)
from deepmd.nvnmd.utils.weight import (
    get_normalize,
)
from deepmd.utils.graph import (
    get_tensor_by_name_from_graph,
)
from deepmd.utils.network import (
    embedding_net,
)

log = logging.getLogger(__name__)


def build_davg_dstd():
    r"""Get the davg and dstd from the dictionary nvnmd_cfg.

    The davg and dstd have been obtained by training CNN.

    Returns
    -------
    davg, dstd
        the normalization parameters extracted from the stored weights.
    """
    davg, dstd = get_normalize(nvnmd_cfg.weight)
    return davg, dstd
def check_switch_range(davg, dstd):
    r"""Check the range of switch, let it in range [-2, 14].

    Parameters
    ----------
    davg
        average value of the descriptor; may be None when the model is
        initialized from a previous model/checkpoint.
    dstd
        standard deviation of the descriptor; may be None likewise.
    """
    rmin = nvnmd_cfg.dscp["rcut_smth"]
    #
    namelist = [n.name for n in tf.get_default_graph().as_graph_def().node]
    if "train_attr/min_nbor_dist" in namelist:
        # the minimal neighbor distance is stored as a graph tensor
        min_dist = get_tensor_by_name_from_graph(
            tf.get_default_graph(), "train_attr/min_nbor_dist"
        )
    elif "train_attr.min_nbor_dist" in nvnmd_cfg.weight.keys():
        # fall back to the value saved in the weight dictionary;
        # a near-zero value is treated as unset and replaced by rcut_smth
        if nvnmd_cfg.weight["train_attr.min_nbor_dist"] < 1e-6:
            min_dist = rmin
        else:
            min_dist = nvnmd_cfg.weight["train_attr.min_nbor_dist"]
    else:
        min_dist = None

    # fix the bug: if model initial mode is 'init_from_model',
    # we need dmin to calculate smin and smax in mapt.py
    if min_dist is not None:
        nvnmd_cfg.dscp["dmin"] = min_dist
        nvnmd_cfg.save()

    # if davg and dstd is None, the model initial mode is in
    # 'init_from_model', 'restart', 'init_from_frz_model', 'finetune'
    if (davg is not None) and (dstd is not None):
        nvnmd_cfg.get_s_range(davg, dstd)
def build_op_descriptor():
    r"""Replace se_a.py/DescrptSeA/build.

    Returns the environment-matrix op: the quantized nvnmd variant when
    descriptor quantization is enabled, otherwise the standard op.
    """
    if nvnmd_cfg.quantize_descriptor:
        return op_module.prod_env_mat_a_nvnmd_quantize
    else:
        return op_module.prod_env_mat_a
def descrpt2r4(inputs, natoms):
    r"""Replace :math:`r_{ji} \rightarrow r'_{ji}`
    where :math:`r_{ji} = (x_{ji}, y_{ji}, z_{ji})` and
    :math:`r'_{ji} = (s_{ji}, \frac{s_{ji} x_{ji}}{r_{ji}},
    \frac{s_{ji} y_{ji}}{r_{ji}}, \frac{s_{ji} z_{ji}}{r_{ji}})`.
    """
    ntypes = nvnmd_cfg.dscp["ntype"]
    NIDP = nvnmd_cfg.dscp["NIDP"]
    ndescrpt = NIDP * 4
    start_index = 0

    # (nf*na*ni, 4)
    inputs_reshape = tf.reshape(inputs, [-1, 4])

    with tf.variable_scope("filter_type_all_x", reuse=True):
        # u (i.e., r^2)
        u = tf.reshape(tf.slice(inputs_reshape, [0, 0], [-1, 1]), [-1, 1])
        with tf.variable_scope("u", reuse=True):
            u = op_module.flt_nvnmd(u)
        log.debug("#u: %s", u)
        u = tf.ensure_shape(u, [None, 1])
        u = tf.reshape(u, [-1, natoms[0] * NIDP])
        sh0 = tf.shape(u)[0]
        # rij
        rij = tf.reshape(tf.slice(inputs_reshape, [0, 1], [-1, 3]), [-1, 3])
        with tf.variable_scope("rij", reuse=True):
            rij = op_module.flt_nvnmd(rij)
        rij = tf.ensure_shape(rij, [None, 3])
        log.debug("#rij: %s", rij)

        s = []
        h = []
        # map u -> (s, h) per atom type through the quantization tables
        for type_i in range(ntypes):
            type_input = 0
            u_i = tf.slice(
                u, [0, start_index * NIDP], [-1, natoms[2 + type_i] * NIDP]
            )
            u_i = tf.reshape(u_i, [-1, 1])
            # s
            table = GLOBAL_NP_FLOAT_PRECISION(
                np.concatenate(
                    [nvnmd_cfg.map["s"][type_i], nvnmd_cfg.map["h"][type_i]], axis=1
                )
            )
            table_grad = GLOBAL_NP_FLOAT_PRECISION(
                np.concatenate(
                    [
                        nvnmd_cfg.map["s_grad"][type_i],
                        nvnmd_cfg.map["h_grad"][type_i],
                    ],
                    axis=1,
                )
            )
            table_info = nvnmd_cfg.map["cfg_u2s"]
            table_info = np.array([np.float64(v) for vs in table_info for v in vs])
            table_info = GLOBAL_NP_FLOAT_PRECISION(table_info)
            s_h_i = op_module.map_flt_nvnmd(u_i, table, table_grad, table_info)
            s_h_i = tf.ensure_shape(s_h_i, [None, 1, 2])
            s_i = tf.slice(s_h_i, [0, 0, 0], [-1, -1, 1])
            h_i = tf.slice(s_h_i, [0, 0, 1], [-1, -1, 1])
            # reshape shape to sh0 for fixing bug.
            # This bug occurs if the number of atoms of an element is zero.
            s_i = tf.reshape(s_i, [sh0, natoms[2 + type_i] * NIDP])
            h_i = tf.reshape(h_i, [sh0, natoms[2 + type_i] * NIDP])
            s.append(s_i)
            h.append(h_i)
            start_index += natoms[2 + type_i]
        s = tf.concat(s, axis=1)
        h = tf.concat(h, axis=1)
        s = tf.reshape(s, [-1, 1])
        h = tf.reshape(h, [-1, 1])
        with tf.variable_scope("s", reuse=True):
            s = op_module.flt_nvnmd(s)
            log.debug("#s: %s", s)
            s = tf.ensure_shape(s, [None, 1])
        with tf.variable_scope("h", reuse=True):
            h = op_module.flt_nvnmd(h)
            log.debug("#h: %s", h)
            h = tf.ensure_shape(h, [None, 1])
        # R2R4
        Rs = s
        # Rxyz = h * rij
        Rxyz = op_module.mul_flt_nvnmd(h, rij)
        Rxyz = tf.ensure_shape(Rxyz, [None, 3])
        with tf.variable_scope("Rxyz", reuse=True):
            Rxyz = op_module.flt_nvnmd(Rxyz)
            log.debug("#Rxyz: %s", Rxyz)
            Rxyz = tf.ensure_shape(Rxyz, [None, 3])
        R4 = tf.concat([Rs, Rxyz], axis=1)
        R4 = tf.reshape(R4, [-1, NIDP, 4])
        inputs_reshape = R4
        inputs_reshape = tf.reshape(inputs_reshape, [-1, ndescrpt])
    return inputs_reshape
def filter_lower_R42GR(
    type_i,
    type_input,
    inputs_i,
    is_exclude,
    activation_fn,
    bavg,
    stddev,
    trainable,
    suffix,
    seed,
    seed_shift,
    uniform_seed,
    filter_neuron,
    filter_precision,
    filter_resnet_dt,
    embedding_net_variables,
):
    r"""Replace se_a.py/DescrptSeA/_filter_lower."""
    shape_i = inputs_i.get_shape().as_list()
    inputs_reshape = tf.reshape(inputs_i, [-1, 4])
    natom = tf.shape(inputs_i)[0]
    M1 = nvnmd_cfg.dscp["M1"]

    type_input = 0 if (type_input < 0) else type_input

    if nvnmd_cfg.quantize_descriptor:
        # copy
        inputs_reshape = op_module.flt_nvnmd(inputs_reshape)
        inputs_reshape = tf.ensure_shape(inputs_reshape, [None, 4])
        inputs_reshape, inputs_reshape2 = op_module.copy_flt_nvnmd(inputs_reshape)
        inputs_reshape = tf.ensure_shape(inputs_reshape, [None, 4])
        inputs_reshape2 = tf.ensure_shape(inputs_reshape2, [None, 4])
        # s
        s = tf.reshape(tf.slice(inputs_reshape, [0, 0], [-1, 1]), [-1, 1])
        # G: map s through the per-type embedding table
        table = GLOBAL_NP_FLOAT_PRECISION(nvnmd_cfg.map["g"][type_i])
        table_grad = GLOBAL_NP_FLOAT_PRECISION(nvnmd_cfg.map["g_grad"][type_i])
        table_info = nvnmd_cfg.map["cfg_s2g"]
        table_info = np.array([np.float64(v) for vs in table_info for v in vs])
        table_info = GLOBAL_NP_FLOAT_PRECISION(table_info)
        with tf.variable_scope("g", reuse=True):
            G = op_module.map_flt_nvnmd(s, table, table_grad, table_info)
            G = tf.ensure_shape(G, [None, 1, M1])
            G = op_module.flt_nvnmd(G)
            G = tf.ensure_shape(G, [None, 1, M1])
            log.debug("#g: %s", G)
        # G
        xyz_scatter = G
        xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1] // 4, M1))
        # GR
        inputs_reshape2 = tf.reshape(inputs_reshape2, [-1, shape_i[1] // 4, 4])
        GR = op_module.matmul_flt2fix_nvnmd(
            tf.transpose(inputs_reshape2, [0, 2, 1]), xyz_scatter, 23
        )
        GR = tf.ensure_shape(GR, [None, 4, M1])
        return GR
    else:
        xyz_scatter = tf.reshape(tf.slice(inputs_reshape, [0, 0], [-1, 1]), [-1, 1])
        if nvnmd_cfg.restore_descriptor:
            # restore the embedding-net weights saved in nvnmd_cfg instead
            # of training them; keys use '.' while TF scopes use '/'
            trainable = False
            embedding_net_variables = {}
            for key in nvnmd_cfg.weight.keys():
                if "filter_type" in key:
                    key2 = key.replace(".", "/")
                    embedding_net_variables[key2] = nvnmd_cfg.weight[key]
        if not is_exclude:
            xyz_scatter = embedding_net(
                xyz_scatter,
                filter_neuron,
                filter_precision,
                activation_fn=activation_fn,
                resnet_dt=filter_resnet_dt,
                name_suffix=suffix,
                stddev=stddev,
                bavg=bavg,
                seed=seed,
                trainable=trainable,
                uniform_seed=uniform_seed,
                initial_variables=embedding_net_variables,
            )
            if (not uniform_seed) and (seed is not None):
                seed += seed_shift
        else:
            # we can safely return the final xyz_scatter filled with zero directly
            return tf.cast(tf.fill((natom, 4, M1), 0.0), GLOBAL_TF_FLOAT_PRECISION)
        # natom x nei_type_i x out_size
        xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1] // 4, M1))
        # When using tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]) below
        # [588 24] -> [588 6 4] correct
        # but if sel is zero
        # [588 0] -> [147 0 4] incorrect; the correct one is [588 0 4]
        # So we need to explicitly assign the shape to tf.shape(inputs_i)[0]
        # instead of -1
        return tf.matmul(
            tf.reshape(inputs_i, [natom, shape_i[1] // 4, 4]),
            xyz_scatter,
            transpose_a=True,
        )
def filter_GR2D(xyz_scatter_1):
    r"""Replace se_a.py/_filter.

    Builds the descriptor matrix D (and the rotation matrix qmat) from the
    GR product, in either the quantized-nvnmd or the plain-TF branch.
    """
    NIX = nvnmd_cfg.dscp["NIX"]
    M1 = nvnmd_cfg.dscp["M1"]
    M2 = nvnmd_cfg.dscp["M2"]
    NBIT_DATA_FL = nvnmd_cfg.nbit["NBIT_FIXD_FL"]

    if nvnmd_cfg.quantize_descriptor:
        xyz_scatter_1 = tf.reshape(xyz_scatter_1, [-1, 4 * M1])
        # fix the number of bits of gradient
        xyz_scatter_1 = xyz_scatter_1 * (1.0 / NIX)
        with tf.variable_scope("gr", reuse=True):
            xyz_scatter_1 = op_module.flt_nvnmd(xyz_scatter_1)
            log.debug("#gr: %s", xyz_scatter_1)
            xyz_scatter_1 = tf.ensure_shape(xyz_scatter_1, [None, 4 * M1])
        xyz_scatter_1 = tf.reshape(xyz_scatter_1, [-1, 4, M1])
        # natom x 4 x outputs_size_2
        xyz_scatter_2 = xyz_scatter_1
        # natom x 3 x outputs_size_1
        qmat = tf.slice(xyz_scatter_1, [0, 1, 0], [-1, 3, -1])
        # natom x outputs_size_2 x 3
        qmat = tf.transpose(qmat, perm=[0, 2, 1])
        # D': natom x outputs_size x outputs_size_2
        xyz_scatter_1_T = tf.transpose(xyz_scatter_1, [0, 2, 1])
        result = op_module.matmul_flt_nvnmd(
            xyz_scatter_1_T, xyz_scatter_2, 1 * 16 + 0, 1 * 16 + 0
        )
        result = tf.ensure_shape(result, [None, M1, M1])
        # D': natom x (outputs_size x outputs_size_2)
        result = tf.reshape(result, [-1, M1 * M1])
        # keep only the M1*M2 banded entries of the symmetric M1*M1 product
        index_subset = []
        for ii in range(M1):
            for jj in range(ii, ii + M2):
                index_subset.append((ii * M1) + (jj % M1))
        index_subset = tf.constant(np.int32(np.array(index_subset)))
        result = tf.gather(result, index_subset, axis=1)

        with tf.variable_scope("d", reuse=True):
            result = op_module.flt_nvnmd(result)
            log.debug("#d: %s", result)
            result = tf.ensure_shape(result, [None, M1 * M2])
        result = op_module.quantize_nvnmd(result, 0, NBIT_DATA_FL, NBIT_DATA_FL, -1)
        result = tf.ensure_shape(result, [None, M1 * M2])
    else:
        # natom x 4 x outputs_size
        xyz_scatter_1 = xyz_scatter_1 * (1.0 / NIX)
        # natom x 4 x outputs_size_2
        # xyz_scatter_2 = tf.slice(xyz_scatter_1, [0,0,0],[-1,-1,outputs_size_2])
        xyz_scatter_2 = xyz_scatter_1
        # natom x 3 x outputs_size_1
        qmat = tf.slice(xyz_scatter_1, [0, 1, 0], [-1, 3, -1])
        # natom x outputs_size_1 x 3
        qmat = tf.transpose(qmat, perm=[0, 2, 1])
        # natom x outputs_size x outputs_size_2
        result = tf.matmul(xyz_scatter_1, xyz_scatter_2, transpose_a=True)
        # natom x (outputs_size x outputs_size_2)
        # result = tf.reshape(result, [-1, outputs_size_2 * outputs_size[-1]])
        result = tf.reshape(result, [-1, M1 * M1])
        # keep only the M1*M2 banded entries of the symmetric M1*M1 product
        index_subset = []
        for ii in range(M1):
            for jj in range(ii, ii + M2):
                index_subset.append((ii * M1) + (jj % M1))
        index_subset = tf.constant(np.int32(np.array(index_subset)))
        result = tf.gather(result, index_subset, axis=1)

    return result, qmat