# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import numpy as np
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
op_module,
tf,
)
#
from deepmd.nvnmd.utils.config import (
nvnmd_cfg,
)
from deepmd.nvnmd.utils.weight import (
get_normalize,
)
from deepmd.utils.graph import (
get_tensor_by_name_from_graph,
)
log = logging.getLogger(__name__)
[docs]def build_davg_dstd():
r"""Get the davg and dstd from the dictionary nvnmd_cfg.
The davg and dstd have been obtained by training CNN.
"""
davg, dstd = get_normalize(nvnmd_cfg.weight)
return davg, dstd
[docs]def check_switch_range(davg, dstd):
r"""Check the range of switch, let it in range [-2, 14]."""
rmin = nvnmd_cfg.dscp["rcut_smth"]
ntype = nvnmd_cfg.dscp["ntype"]
NIDP = nvnmd_cfg.dscp["NIDP"]
ndescrpt = NIDP * 4
#
namelist = [n.name for n in tf.get_default_graph().as_graph_def().node]
if "train_attr/min_nbor_dist" in namelist:
min_dist = get_tensor_by_name_from_graph(
tf.get_default_graph(), "train_attr/min_nbor_dist"
)
elif "train_attr.min_nbor_dist" in nvnmd_cfg.weight.keys():
if nvnmd_cfg.weight["train_attr.min_nbor_dist"] < 1e-6:
min_dist = rmin
else:
min_dist = nvnmd_cfg.weight["train_attr.min_nbor_dist"]
else:
min_dist = None
# fix the bug: if model initial mode is 'init_from_model',
# we need dmin to calculate smin and smax in mapt.py
if min_dist is not None:
nvnmd_cfg.dscp["dmin"] = min_dist
nvnmd_cfg.save()
# if davg and dstd is None, the model initial mode is in
# 'init_from_model', 'restart', 'init_from_frz_model', 'finetune'
if (davg is not None) or (dstd is not None):
if davg is None:
davg = np.zeros([ntype, ndescrpt])
if dstd is None:
dstd = np.ones([ntype, ndescrpt])
nvnmd_cfg.get_s_range(davg, dstd)
[docs]def build_op_descriptor():
r"""Replace se_a.py/DescrptSeA/build."""
if nvnmd_cfg.quantize_descriptor:
return op_module.prod_env_mat_a_mix_nvnmd_quantize
else:
return op_module.prod_env_mat_a_mix
[docs]def descrpt2r4(inputs, atype):
r"""Replace :math:`r_{ji} \rightarrow r'_{ji}`
where :math:`r_{ji} = (x_{ji}, y_{ji}, z_{ji})` and
:math:`r'_{ji} = (s_{ji}, \frac{s_{ji} x_{ji}}{r_{ji}}, \frac{s_{ji} y_{ji}}{r_{ji}}, \frac{s_{ji} z_{ji}}{r_{ji}})`.
"""
NIDP = nvnmd_cfg.dscp["NIDP"]
ndescrpt = NIDP * 4
# (nf*na*ni, 4)
inputs_reshape = tf.reshape(inputs, [-1, 4])
# u (i.e., r^2)
u = tf.reshape(tf.slice(inputs_reshape, [0, 0], [-1, 1]), [-1, 1])
with tf.variable_scope("u", reuse=True):
u = op_module.flt_nvnmd(u)
log.debug("#u: %s", u)
u = tf.ensure_shape(u, [None, 1])
# rji
rji = tf.reshape(tf.slice(inputs_reshape, [0, 1], [-1, 3]), [-1, 3])
with tf.variable_scope("rji", reuse=True):
rji = op_module.flt_nvnmd(rji)
rji = tf.ensure_shape(rji, [None, 3])
log.debug("#rji: %s", rji)
# s & h
u = tf.reshape(u, [-1, 1])
table = GLOBAL_NP_FLOAT_PRECISION(
np.concatenate([nvnmd_cfg.map["s"][0], nvnmd_cfg.map["h"][0]], axis=1)
)
table_grad = GLOBAL_NP_FLOAT_PRECISION(
np.concatenate(
[nvnmd_cfg.map["s_grad"][0], nvnmd_cfg.map["h_grad"][0]],
axis=1,
)
)
table_info = nvnmd_cfg.map["cfg_u2s"]
table_info = np.array([np.float64(v) for vs in table_info for v in vs])
table_info = GLOBAL_NP_FLOAT_PRECISION(table_info)
s_h = op_module.map_flt_nvnmd(u, table, table_grad, table_info)
s_h = tf.ensure_shape(s_h, [None, 1, 2])
s = tf.slice(s_h, [0, 0, 0], [-1, -1, 1])
h = tf.slice(s_h, [0, 0, 1], [-1, -1, 1])
s = tf.reshape(s, [-1, 1])
h = tf.reshape(h, [-1, 1])
with tf.variable_scope("s_s", reuse=True):
s = op_module.flt_nvnmd(s)
log.debug("#s_s: %s", s)
s = tf.ensure_shape(s, [None, 1])
with tf.variable_scope("h_s", reuse=True):
h = op_module.flt_nvnmd(h)
log.debug("#h_s: %s", h)
h = tf.ensure_shape(h, [None, 1])
# davg and dstd
# davg = nvnmd_cfg.map["davg"] # is_zero
dstd_inv = nvnmd_cfg.map["dstd_inv"]
atype_expand = tf.reshape(atype, [-1, 1])
std_inv_sel = tf.nn.embedding_lookup(dstd_inv, atype_expand)
std_inv_sel = tf.reshape(std_inv_sel, [-1, 4])
std_inv_s = tf.slice(std_inv_sel, [0, 0], [-1, 1])
std_inv_h = tf.slice(std_inv_sel, [0, 1], [-1, 1])
s = op_module.mul_flt_nvnmd(std_inv_s, tf.reshape(s, [-1, NIDP]))
h = op_module.mul_flt_nvnmd(std_inv_h, tf.reshape(h, [-1, NIDP]))
s = tf.ensure_shape(s, [None, NIDP])
h = tf.ensure_shape(h, [None, NIDP])
s = tf.reshape(s, [-1, 1])
h = tf.reshape(h, [-1, 1])
with tf.variable_scope("s", reuse=True):
s = op_module.flt_nvnmd(s)
log.debug("#s: %s", s)
s = tf.ensure_shape(s, [None, 1])
with tf.variable_scope("h", reuse=True):
h = op_module.flt_nvnmd(h)
log.debug("#h: %s", h)
h = tf.ensure_shape(h, [None, 1])
# R2R4
Rs = s
# Rxyz = h * rji
Rxyz = op_module.mul_flt_nvnmd(h, rji)
Rxyz = tf.ensure_shape(Rxyz, [None, 3])
with tf.variable_scope("Rxyz", reuse=True):
Rxyz = op_module.flt_nvnmd(Rxyz)
log.debug("#Rxyz: %s", Rxyz)
Rxyz = tf.ensure_shape(Rxyz, [None, 3])
R4 = tf.concat([Rs, Rxyz], axis=1)
inputs_reshape = R4
inputs_reshape = tf.reshape(inputs_reshape, [-1, ndescrpt])
return inputs_reshape
[docs]def filter_lower_R42GR(inputs_i, atype, nei_type_vec):
r"""Replace se_a.py/DescrptSeA/_filter_lower."""
shape_i = inputs_i.get_shape().as_list()
inputs_reshape = tf.reshape(inputs_i, [-1, 4])
M1 = nvnmd_cfg.dscp["M1"]
ntype = nvnmd_cfg.dscp["ntype"]
NIDP = nvnmd_cfg.dscp["NIDP"]
two_embd_value = nvnmd_cfg.map["gt"]
# print(two_embd_value)
# copy
inputs_reshape = op_module.flt_nvnmd(inputs_reshape)
inputs_reshape = tf.ensure_shape(inputs_reshape, [None, 4])
inputs_reshape, inputs_reshape2 = op_module.copy_flt_nvnmd(inputs_reshape)
inputs_reshape = tf.ensure_shape(inputs_reshape, [None, 4])
inputs_reshape2 = tf.ensure_shape(inputs_reshape2, [None, 4])
# s2G
s = tf.reshape(tf.slice(inputs_reshape, [0, 0], [-1, 1]), [-1, 1])
table = GLOBAL_NP_FLOAT_PRECISION(nvnmd_cfg.map["g"][0])
table_grad = GLOBAL_NP_FLOAT_PRECISION(nvnmd_cfg.map["g_grad"][0])
table_info = nvnmd_cfg.map["cfg_s2g"]
table_info = np.array([np.float64(v) for vs in table_info for v in vs])
table_info = GLOBAL_NP_FLOAT_PRECISION(table_info)
G = op_module.map_flt_nvnmd(s, table, table_grad, table_info)
G = tf.ensure_shape(G, [None, 1, M1])
with tf.variable_scope("g_s", reuse=True):
G = op_module.flt_nvnmd(G)
log.debug("#g_s: %s", G)
G = tf.ensure_shape(G, [None, 1, M1])
# t2G
atype_expand = tf.reshape(atype, [-1, 1])
idx_i = tf.tile(atype_expand * (ntype + 1), [1, NIDP])
idx_j = tf.reshape(nei_type_vec, [-1, NIDP])
idx = idx_i + idx_j
index_of_two_side = tf.reshape(idx, [-1])
two_embd = tf.nn.embedding_lookup(two_embd_value, index_of_two_side)
# two_embd = tf.reshape(two_embd, (-1, shape_i[1] // 4, M1))
two_embd = tf.reshape(two_embd, (-1, M1))
with tf.variable_scope("g_t", reuse=True):
two_embd = op_module.flt_nvnmd(two_embd)
log.debug("#g_t: %s", two_embd)
two_embd = tf.ensure_shape(two_embd, [None, M1])
# G_s, G_t -> G
G = tf.reshape(G, [-1, M1])
G = op_module.mul_flt_nvnmd(G, two_embd)
G = tf.ensure_shape(G, [None, M1])
with tf.variable_scope("g", reuse=True):
G = op_module.flt_nvnmd(G)
log.debug("#g: %s", G)
G = tf.ensure_shape(G, [None, M1])
xyz_scatter = G
xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1] // 4, M1))
# GR
inputs_reshape2 = tf.reshape(inputs_reshape2, [-1, shape_i[1] // 4, 4])
GR = op_module.matmul_flt2fix_nvnmd(
tf.transpose(inputs_reshape2, [0, 2, 1]), xyz_scatter, 23
)
GR = tf.ensure_shape(GR, [None, 4, M1])
return GR
[docs]def filter_GR2D(xyz_scatter_1):
r"""Replace se_a.py/_filter."""
NIX = nvnmd_cfg.dscp["NIX"]
M1 = nvnmd_cfg.dscp["M1"]
M2 = nvnmd_cfg.dscp["M2"]
NBIT_DATA_FL = nvnmd_cfg.nbit["NBIT_FIXD_FL"]
if nvnmd_cfg.quantize_descriptor:
xyz_scatter_1 = tf.reshape(xyz_scatter_1, [-1, 4 * M1])
# fix the number of bits of gradient
xyz_scatter_1 = xyz_scatter_1 * (1.0 / NIX)
with tf.variable_scope("gr", reuse=True):
xyz_scatter_1 = op_module.flt_nvnmd(xyz_scatter_1)
log.debug("#gr: %s", xyz_scatter_1)
xyz_scatter_1 = tf.ensure_shape(xyz_scatter_1, [None, 4 * M1])
xyz_scatter_1 = tf.reshape(xyz_scatter_1, [-1, 4, M1])
# natom x 4 x outputs_size_2
xyz_scatter_2 = xyz_scatter_1
# natom x 3 x outputs_size_1
qmat = tf.slice(xyz_scatter_1, [0, 1, 0], [-1, 3, -1])
# natom x outputs_size_2 x 3
qmat = tf.transpose(qmat, perm=[0, 2, 1])
# D': natom x outputs_size x outputs_size_2
xyz_scatter_1_T = tf.transpose(xyz_scatter_1, [0, 2, 1])
result = op_module.matmul_flt_nvnmd(
xyz_scatter_1_T, xyz_scatter_2, 1 * 16 + 0, 1 * 16 + 0
)
result = tf.ensure_shape(result, [None, M1, M1])
# D': natom x (outputs_size x outputs_size_2)
result = tf.reshape(result, [-1, M1 * M1])
#
index_subset = []
for ii in range(M1):
for jj in range(ii, ii + M2):
index_subset.append((ii * M1) + (jj % M1))
index_subset = tf.constant(np.int32(np.array(index_subset)))
result = tf.gather(result, index_subset, axis=1)
with tf.variable_scope("d", reuse=True):
result = op_module.flt_nvnmd(result)
log.debug("#d: %s", result)
result = tf.ensure_shape(result, [None, M1 * M2])
result = op_module.quantize_nvnmd(result, 0, NBIT_DATA_FL, NBIT_DATA_FL, -1)
result = tf.ensure_shape(result, [None, M1 * M2])
else:
# natom x 4 x outputs_size
xyz_scatter_1 = xyz_scatter_1 * (1.0 / NIX)
# natom x 4 x outputs_size_2
# xyz_scatter_2 = tf.slice(xyz_scatter_1, [0,0,0],[-1,-1,outputs_size_2])
xyz_scatter_2 = xyz_scatter_1
# natom x 3 x outputs_size_1
qmat = tf.slice(xyz_scatter_1, [0, 1, 0], [-1, 3, -1])
# natom x outputs_size_1 x 3
qmat = tf.transpose(qmat, perm=[0, 2, 1])
# natom x outputs_size x outputs_size_2
result = tf.matmul(xyz_scatter_1, xyz_scatter_2, transpose_a=True)
# natom x (outputs_size x outputs_size_2)
# result = tf.reshape(result, [-1, outputs_size_2 * outputs_size[-1]])
result = tf.reshape(result, [-1, M1 * M1])
#
index_subset = []
for ii in range(M1):
for jj in range(ii, ii + M2):
index_subset.append((ii * M1) + (jj % M1))
index_subset = tf.constant(np.int32(np.array(index_subset)))
result = tf.gather(result, index_subset, axis=1)
return result, qmat