Source code for deepmd.nvnmd.entrypoints.wrap

# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
    Optional,
)

import numpy as np

from deepmd.env import (
    op_module,
    tf,
)
from deepmd.nvnmd.data.data import (
    jdata_deepmd_input_v0,
    jdata_sys,
)
from deepmd.nvnmd.utils.config import (
    nvnmd_cfg,
)
from deepmd.nvnmd.utils.encode import (
    Encode,
)
from deepmd.nvnmd.utils.fio import (
    FioBin,
    FioTxt,
)
from deepmd.nvnmd.utils.network import (
    get_sess,
)
from deepmd.nvnmd.utils.weight import (
    get_fitnet_weight,
    get_type_weight,
)
from deepmd.utils.sess import (
    run_sess,
)

log = logging.getLogger(__name__)


class Wrap:
    r"""Generate the binary model file (model.pb).

    The model file can be used to run the NVNMD with LAMMPS.
    The pair style needs to be set as:

    .. code-block:: lammps

        pair_style nvnmd model.pb
        pair_coeff * *

    Parameters
    ----------
    config_file
        input file name
        an .npy file containing the configuration information of NVNMD model
    weight_file
        input file name
        an .npy file containing the weights of NVNMD model
    map_file
        input file name
        an .npy file containing the mapping tables of NVNMD model
    model_file
        output file name
        an .pb file containing the model using in the NVNMD

    References
    ----------
    DOI: 10.1038/s41524-022-00773-z
    """

    def __init__(
        self, config_file: str, weight_file: str, map_file: str, model_file: str
    ):
        self.config_file = config_file
        self.weight_file = weight_file
        self.map_file = map_file
        self.model_file = model_file

        # Work on a shallow copy: the template is a module-level default and
        # must not be left carrying this run's file names and `enable` flag.
        jdata = dict(jdata_deepmd_input_v0["nvnmd"])
        jdata["config_file"] = config_file
        jdata["weight_file"] = weight_file
        jdata["map_file"] = map_file
        jdata["enable"] = True

        nvnmd_cfg.init_from_jdata(jdata)

    @staticmethod
    def _type_counts():
        r"""Return ``(ntype, ntype_max)`` used to lay out per-type tables.

        Version 0 reads both values from the descriptor configuration;
        version 1 uses a single slot.

        Raises
        ------
        ValueError
            if ``nvnmd_cfg.version`` is neither 0 nor 1
        """
        version = nvnmd_cfg.version
        if version == 0:
            dscp = nvnmd_cfg.dscp
            return dscp["ntype"], dscp["ntype_max"]
        if version == 1:
            return 1, 1
        raise ValueError(
            f"unsupported NVNMD version: {version!r} (expected 0 or 1)"
        )

    def wrap(self):
        r"""Encode all sections of the model and write the model file.

        The sections are, in order: ``cfg`` (descriptor configuration),
        ``fps``/``bps`` (fitting-net parameters), ``swt``/``dsw``/``fea``/``gra``
        (mapping tables) and, for version 1 only, ``std``/``gtt``/``avc``
        (look-up tables). The number of rows and the hex width of every section
        are stored in ``nvnmd_cfg.size`` (``NH_DATA``/``NW_DATA``), the
        configuration is saved back to ``nvnmd_cfg.config_file``, and the head
        followed by all sections is written to ``self.model_file``.

        Raises
        ------
        ValueError
            if ``nvnmd_cfg.version`` is neither 0 nor 1
        """
        version = nvnmd_cfg.version
        if version not in (0, 1):
            raise ValueError(
                f"unsupported NVNMD version: {version!r} (expected 0 or 1)"
            )

        e = Encode()
        # cfg
        bcfg = self.wrap_dscp()
        # split data with 72 bits per row
        hcfg = e.bin2hex(e.split_bin(bcfg, 72))
        # the cfg section is extended to at least 128 rows
        hcfg = e.extend_list(hcfg, max(len(hcfg), 128))

        # bfps & bbps
        bfps, bbps = self.wrap_fitn()
        hfps = e.bin2hex(e.split_bin(bfps, 72))
        hbps = e.bin2hex(e.split_bin(bbps, 72))

        # bswt, bdsw, bfea, bgra
        bswt, bdsw, bfea, bgra = self.wrap_map()
        hswt = e.bin2hex(bswt)
        hdsw = e.bin2hex(bdsw)
        hfea = e.bin2hex(bfea)
        hgra = e.bin2hex(bgra)

        if version == 0:
            keys = "cfg fps bps swt dsw fea gra".split()
            datas = [hcfg, hfps, hbps, hswt, hdsw, hfea, hgra]
        else:
            # bstd, bgtt, bavc
            bstd, bgtt, bavc = self.wrap_lut()
            hstd = e.bin2hex(bstd)
            hgtt = e.bin2hex(bgtt)
            havc = e.bin2hex(bavc)
            keys = "cfg fps bps swt dsw std fea gra gtt avc".split()
            datas = [hcfg, hfps, hbps, hswt, hdsw, hstd, hfea, hgra, hgtt, havc]

        # extend data according to the number of bits per row of BRAM
        nbit = 32
        nhs = []
        nws = []
        padded = []
        for key, rows in zip(keys, datas):
            nrow = len(rows)
            nhex = len(rows[0])  # width in hex digits, before padding
            nhs.append(nrow)
            nws.append(nhex)
            # round the bit width (4 bits per hex digit) up to a multiple of nbit
            w_full = np.ceil(nhex * 4 / nbit) * nbit
            rows = e.extend_hex(rows, w_full)
            # DEVELOP_DEBUG
            if jdata_sys["debug"]:
                log.info("%s: %d x % d bit", key, nrow, nhex * 4)
                FioTxt().save("nvnmd/wrap/h%s.txt" % (key), rows)
            padded.append(rows)

        # update h & w of nvnmd_cfg
        nvnmd_cfg.size["NH_DATA"] = nhs
        nvnmd_cfg.size["NW_DATA"] = nws
        nvnmd_cfg.save(nvnmd_cfg.config_file)
        head = self.wrap_head(nhs, nws)

        # output model
        hs = [*head]
        for rows in padded:
            hs.extend(rows)

        FioBin().save(self.model_file, hs)
        log.info("NVNMD: finish wrapping model file")

    def wrap_head(self, nhs, nws):
        r"""Wrap the head information.

        The head is a sequence of ``NBIT_MODEL_HEAD``-bit fields:

        version
            ``VERSION + 256 * SUB_VERSION + 256 * 256 * MAX_NNEI``
        nhead
            fixed to 128
        nheight
            one field per entry of `nhs`
        nwidth
            one field per entry of `nws`
        rcut
            cut-off radius
        ntype
            number of atomic species
        nnei
            number of neighbors
        atom_ener
            atom bias energy

        Parameters
        ----------
        nhs
            number of rows of each data section
        nws
            width (in hex digits) of each data section

        Returns
        -------
        the hex-encoded head, extended to ``NBIT_MODEL_HEAD * 128`` bits
        """
        nbit = nvnmd_cfg.nbit
        ctrl = nvnmd_cfg.ctrl
        dscp = nvnmd_cfg.dscp
        fitn = nvnmd_cfg.fitn
        weight = nvnmd_cfg.weight
        VERSION = ctrl["VERSION"]
        SUB_VERSION = ctrl["SUB_VERSION"]
        MAX_NNEI = ctrl["MAX_NNEI"]
        nhead = 128
        NBIT_MODEL_HEAD = nbit["NBIT_MODEL_HEAD"]
        NBIT_FIXD_FL = nbit["NBIT_FIXD_FL"]
        rcut = dscp["rcut"]
        ntype = dscp["ntype"]
        SEL = dscp["SEL"]

        e = Encode()
        # Fields are collected in the order they are produced. Every new field
        # goes in FRONT of the ones already emitted, so the list is joined in
        # reverse at the end.
        fields = []
        # version
        vv = VERSION + 256 * SUB_VERSION + 256 * 256 * MAX_NNEI
        fields.append(e.dec2bin(vv, NBIT_MODEL_HEAD)[0])
        # nhead
        fields.append(e.dec2bin(nhead, NBIT_MODEL_HEAD)[0])
        # height
        for n in nhs:
            fields.append(e.dec2bin(n, NBIT_MODEL_HEAD)[0])
        # width
        for n in nws:
            fields.append(e.dec2bin(n, NBIT_MODEL_HEAD)[0])
        # rcut
        RCUT = e.qr(rcut, NBIT_FIXD_FL)
        fields.append(e.dec2bin(RCUT, NBIT_MODEL_HEAD)[0])
        # ntype
        fields.append(e.dec2bin(ntype, NBIT_MODEL_HEAD)[0])
        # nnei: one field per type in version 0, a single field in version 1
        if VERSION == 0:
            for tt in range(ntype):
                fields.append(e.dec2bin(SEL[tt], NBIT_MODEL_HEAD)[0])
        if VERSION == 1:
            fields.append(e.dec2bin(SEL, NBIT_MODEL_HEAD)[0])
        # atom_ener
        # fix the bug: the different energy between qnn and lammps
        if "t_bias_atom_e" in weight:
            atom_ener = weight["t_bias_atom_e"]
        else:
            atom_ener = [0] * 32
        nlayer_fit = fitn["nlayer_fit"]
        if VERSION in (0, 1):
            for tt in range(ntype):
                # version 0 reads the last-layer bias of the fitting net of
                # type tt; version 1 always reads the net with index 0
                net_idx = tt if VERSION == 0 else 0
                _w, b, _idt = get_fitnet_weight(
                    weight, net_idx, nlayer_fit - 1, nlayer_fit
                )
                shift = atom_ener[tt] + b[0]
                SHIFT = e.qr(shift, NBIT_FIXD_FL)
                fields.append(e.dec2bin(SHIFT, NBIT_MODEL_HEAD, signed=True)[0])
        bs = "".join(reversed(fields))
        # extend
        hs = e.bin2hex(bs)
        hs = e.extend_hex(hs, NBIT_MODEL_HEAD * nhead)
        return hs

    def wrap_dscp(self):
        r"""Wrap the configuration of descriptor.

        version 0:
        [NBIT_IDX_S2G-1:0] SHIFT_IDX_S2G
        [NBIT_NEIB*NTYPE-1:0] SELs
        [NBIT_FIXD*M1*NTYPE*NTYPE-1:0] GSs
        [NBIT_FLTE-1:0] NEXPO_DIV_NI

        version 1:
        [NBIT_FLTE-1:0] NEXPO_DIV_NI

        Returns
        -------
        the concatenated binary string; it is empty when
        ``nvnmd_cfg.version`` is neither 0 nor 1
        """
        dscp = nvnmd_cfg.dscp
        nbit = nvnmd_cfg.nbit
        mapt = nvnmd_cfg.map
        version = nvnmd_cfg.version

        bs = ""
        e = Encode()
        if version == 0:
            NBIT_IDX_S2G = nbit["NBIT_IDX_S2G"]
            NBIT_NEIB = nbit["NBIT_NEIB"]
            NBIT_FLTE = nbit["NBIT_FLTE"]
            NBIT_FIXD = nbit["NBIT_FIXD"]
            NBIT_FIXD_FL = nbit["NBIT_FIXD_FL"]
            M1 = dscp["M1"]
            ntype = dscp["ntype"]
            ntype_max = dscp["ntype_max"]
            # shift_idx_s2g
            x_st, _x_ed, x_dt, _n0, _n1 = mapt["cfg_s2g"][0]
            shift_idx_s2g = int(np.round(-x_st / x_dt))
            bs = e.dec2bin(shift_idx_s2g, NBIT_IDX_S2G)[0] + bs
            # sel: exactly four slots are written
            # NOTE(review): presumably this matches ntype_max == 4 -- confirm
            SEL = dscp["SEL"]
            for ii in range(4):
                bs = e.dec2bin(SEL[ii], NBIT_NEIB)[0] + bs
            # GS
            tf.reset_default_graph()
            t_x = tf.placeholder(tf.float64, [None, 1], "t_x")
            t_table = tf.placeholder(tf.float64, [None, None], "t_table")
            t_table_grad = tf.placeholder(tf.float64, [None, None], "t_table_grad")
            t_table_info = tf.placeholder(tf.float64, [None], "t_table_info")
            t_y = op_module.map_flt_nvnmd(t_x, t_table, t_table_grad, t_table_info)
            sess = get_sess()
            # the flattened table descriptions do not depend on the type pair
            cfg_u2s = np.array([np.float64(v) for vs in mapt["cfg_u2s"] for v in vs])
            cfg_s2g = np.array([np.float64(v) for vs in mapt["cfg_s2g"] for v in vs])
            # GS, when r = 0
            GSs = []
            for tt in range(ntype_max):
                for tt2 in range(ntype_max):
                    if (tt < ntype) and (tt2 < ntype):
                        # s: evaluate the "s" table of type tt at input 0
                        mi = mapt["s"][tt]
                        feed_dict = {
                            t_x: np.ones([1, 1]) * 0.0,
                            t_table: mi,
                            t_table_grad: mi * 0.0,
                            t_table_info: cfg_u2s,
                        }
                        si = run_sess(sess, t_y, feed_dict=feed_dict)
                        si = np.reshape(si, [-1])[0]
                        # G: evaluate the "g" table of type tt2 at that s
                        mi = mapt["g"][tt2]
                        feed_dict = {
                            t_x: np.ones([1, 1]) * si,
                            t_table: mi,
                            t_table_grad: mi * 0.0,
                            t_table_info: cfg_s2g,
                        }
                        gi = run_sess(sess, t_y, feed_dict=feed_dict)
                        gsi = np.reshape(si, [-1]) * np.reshape(gi, [-1])
                    else:
                        # unused type slots are filled with zeros
                        gsi = np.zeros(M1)
                    for ii in range(M1):
                        GSs.extend(
                            e.dec2bin(e.qr(gsi[ii], NBIT_FIXD_FL), NBIT_FIXD, True)
                        )
            sGSs = "".join(GSs[::-1])
            bs = sGSs + bs
            #
            NIX = dscp["NIX"]
            ln2_NIX = -int(np.log2(NIX))
            bs = e.dec2bin(ln2_NIX, NBIT_FLTE, signed=True)[0] + bs
        elif version == 1:
            NBIT_FLTE = nbit["NBIT_FLTE"]
            NIX = dscp["NIX"]
            ln2_NIX = -int(np.log2(NIX))
            bs = e.dec2bin(ln2_NIX, NBIT_FLTE, signed=True)[0] + bs
        return bs

    def wrap_fitn(self):
        r"""Wrap the weights of fitting net.

        w weight
        b bias

        Every layer of every type slot is encoded first (bias, row/column
        exponents, row/column normalized weights). The encoded strings are
        then gathered into ``NSEL`` streams for the forward pass (`bfps`)
        and ``NSEL`` streams for the backward pass (`bbps`).

        Returns
        -------
        bfps
            list of ``NSEL`` binary strings (forward)
        bbps
            list of ``NSEL`` binary strings (backward)

        Raises
        ------
        ValueError
            if ``nvnmd_cfg.version`` is neither 0 nor 1
        """
        fitn = nvnmd_cfg.fitn
        weight = nvnmd_cfg.weight
        nbit = nvnmd_cfg.nbit
        ctrl = nvnmd_cfg.ctrl
        ntype, ntype_max = self._type_counts()
        nlayer_fit = fitn["nlayer_fit"]
        NNODE_FITS = fitn["NNODE_FITS"]
        NBIT_FIT_DATA_FL = nbit["NBIT_FIT_DATA_FL"]
        NBIT_FIT_WEIGHT = nbit["NBIT_FIT_WEIGHT"]
        NBIT_FIT_DISP = nbit["NBIT_FIT_DISP"]
        NBIT_FIT_WXDB = nbit["NBIT_FIT_WXDB"]
        NSTDM = ctrl["NSTDM"]
        NSEL = ctrl["NSEL"]

        # encode all parameters; every container is indexed as [layer][type]
        bb, bdr, bdc, bwr, bwc = [], [], [], [], []
        for ll in range(nlayer_fit):
            bbt, bdrt, bdct, bwrt, bwct = [], [], [], [], []
            for tt in range(ntype_max):
                # get parameters: weight and bias
                if tt < ntype:
                    w, b, _idt = get_fitnet_weight(weight, tt, ll, nlayer_fit)
                else:
                    # unused type slot: same shape as type 0, all zeros
                    w, b, _idt = get_fitnet_weight(weight, 0, ll, nlayer_fit)
                    w = w * 0
                    b = b * 0
                # restrict the shift value of energy
                if ll == (nlayer_fit - 1):
                    b = b * 0
                bbi = self.wrap_bias(b, NBIT_FIT_WXDB, NBIT_FIT_DATA_FL)
                bdri, bdci, bwri, bwci = self.wrap_weight(
                    w, NBIT_FIT_DISP, NBIT_FIT_WEIGHT
                )
                bbt.append(bbi)
                bdrt.append(bdri)
                bdct.append(bdci)
                bwrt.append(bwri)
                bwct.append(bwci)
            bb.append(bbt)
            bdr.append(bdrt)
            bdc.append(bdct)
            bwr.append(bwrt)
            bwc.append(bwct)
        #
        bfps, bbps = [], []
        for ss in range(NSEL):
            tt = ss // NSTDM
            sc = ss % NSTDM  # sub-block index along the columns
            sr = ss % NSTDM  # sub-block index along the rows
            bfp, bbp = [], []
            for ll in range(nlayer_fit):
                nr = NNODE_FITS[ll]
                nc = NNODE_FITS[ll + 1]
                nrs = int(np.ceil(nr / NSTDM))
                ncs = int(np.ceil(nc / NSTDM))
                if nc == 1:
                    # single output column: forward and backward streams
                    # receive the same data and no column offset is applied
                    chunk = [
                        bwc[ll][tt][sr * nrs + rr][cc]
                        for rr in range(nrs)
                        for cc in range(nc)
                    ]
                    chunk += [bdc[ll][tt][cc] for cc in range(ncs)]
                    chunk += [bb[ll][tt][cc] for cc in range(ncs)]
                    # fp
                    bfp += chunk
                    # bp
                    bbp += chunk
                else:
                    # fp
                    bfp += [
                        bwc[ll][tt][rr][sc * ncs + cc]
                        for cc in range(ncs)
                        for rr in range(nr)
                    ]
                    bfp += [bdc[ll][tt][sc * ncs + cc] for cc in range(ncs)]
                    bfp += [bb[ll][tt][sc * ncs + cc] for cc in range(ncs)]
                    # bp
                    bbp += [
                        bwr[ll][tt][sr * nrs + rr][cc]
                        for rr in range(nrs)
                        for cc in range(nc)
                    ]
                    # NOTE(review): bdr holds one exponent per ROW but is
                    # indexed here with the column offset; kept as-is --
                    # confirm against the hardware layout
                    bbp += [bdr[ll][tt][sc * ncs + cc] for cc in range(ncs)]
                    bbp += [bb[ll][tt][sc * ncs + cc] for cc in range(ncs)]
            bfps.append("".join(bfp[::-1]))
            bbps.append("".join(bbp[::-1]))
        return bfps, bbps

    def wrap_bias(self, bias, NBIT_DATA, NBIT_DATA_FL):
        r"""Quantize the bias and encode it as signed binary strings.

        Parameters
        ----------
        bias
            bias values of one layer
        NBIT_DATA
            bit width passed to ``Encode.dec2bin`` (signed encoding)
        NBIT_DATA_FL
            bit count passed to ``Encode.qr`` for the quantization
        """
        e = Encode()
        bias = e.qr(bias, NBIT_DATA_FL)
        Bs = e.dec2bin(bias, NBIT_DATA, True)
        return Bs

    def wrap_weight(self, weight, NBIT_DISP, NBIT_WEIGHT):
        r"""Normalize and encode the weights row-wise and column-wise.

        weight: weights of fittingNet
        NBIT_DISP: nbits of exponent of weight max value
        NBIT_WEIGHT: nbits of mantissa of weights.

        Returns
        -------
        NRs
            encoded exponents, one per row
        NCs
            encoded exponents, one per column
        WRs
            row-normalized weights as a nested list ``[row][column]``
        WCs
            column-normalized weights as a nested list ``[row][column]``
        """
        NBIT_WEIGHT_FL = NBIT_WEIGHT - 2
        nr, nc = weight.shape[0], weight.shape[1]
        nrs = np.zeros(nr)
        ncs = np.zeros(nc)
        wrs = np.zeros([nr, nc])
        wcs = np.zeros([nr, nc])
        e = Encode()
        # row
        for ii in range(nr):
            wi = weight[ii, :]
            wi, expo_max = e.norm_expo(wi, NBIT_WEIGHT_FL, 0)
            nrs[ii] = expo_max
            wrs[ii, :] = wi
        # column
        for ii in range(nc):
            wi = weight[:, ii]
            wi, expo_max = e.norm_expo(wi, NBIT_WEIGHT_FL, 0)
            ncs[ii] = expo_max
            wcs[:, ii] = wi
        NRs = e.dec2bin(nrs, NBIT_DISP)
        NCs = e.dec2bin(ncs, NBIT_DISP)
        # dec2bin output is indexed flat below (nc * rr + cc); regroup it
        # into one list per row
        wrs = e.qr(wrs, NBIT_WEIGHT_FL)
        WRs = e.dec2bin(wrs, NBIT_WEIGHT, True)
        WRs = [[WRs[nc * rr + cc] for cc in range(nc)] for rr in range(nr)]
        wcs = e.qr(wcs, NBIT_WEIGHT_FL)
        WCs = e.dec2bin(wcs, NBIT_WEIGHT, True)
        WCs = [[WCs[nc * rr + cc] for cc in range(nc)] for rr in range(nr)]
        return NRs, NCs, WRs, WCs

    def wrap_map(self):
        r"""Wrap the mapping table of embedding network.

        Four tables are encoded: ``s``/``h`` joined along axis 1 (`swt`),
        their gradients (`dsw`), ``g`` (`fea`) and its gradient (`gra`).
        Type slots beyond ``ntype`` are filled with type-0 data multiplied
        by zero. The tables stored in ``nvnmd_cfg.map`` are not modified.

        Returns
        -------
        bswt, bdsw, bfea, bgra
            the encoded tables

        Raises
        ------
        ValueError
            if ``nvnmd_cfg.version`` is neither 0 nor 1
        """
        dscp = nvnmd_cfg.dscp
        maps = nvnmd_cfg.map
        nbit = nvnmd_cfg.nbit
        version = nvnmd_cfg.version
        M1 = dscp["M1"]
        NBIT_FLTE = nbit["NBIT_FLTE"]
        NBIT_FLTF = nbit["NBIT_FLTF"]
        ntype, ntype_max = self._type_counts()

        e = Encode()
        # get mapt
        swts = []
        dsws = []
        feas = []
        gras = []
        for tt in range(ntype_max):
            used = tt < ntype
            src = tt if used else 0
            swt = np.concatenate([maps["s"][src], maps["h"][src]], axis=1)
            dsw = np.concatenate([maps["s_grad"][src], maps["h_grad"][src]], axis=1)
            fea = maps["g"][src]
            gra = maps["g_grad"][src]
            if used:
                fea = fea.copy()
                gra = gra.copy()
            else:
                # Non-in-place multiplication: `fea`/`gra` alias the tables
                # held by nvnmd_cfg.map, so `*= 0` would wipe the type-0
                # tables of the shared configuration.
                swt = swt * 0
                dsw = dsw * 0
                fea = fea * 0
                gra = gra * 0
            swts.append(swt)
            dsws.append(dsw)
            feas.append(fea)
            gras.append(gra)
        mapts = [swts, dsws, feas, gras]
        # reshape
        bss = []
        if version == 0:
            nmerges = [2 * 2, 2 * 2, 4 * 2, 4 * 2]
            for d, nmerge in zip(mapts, nmerges):
                d = np.reshape(d, [ntype_max, -1, 4])
                d1 = d[:, :, 0:2]
                d2 = d[:, :, 2:4]
                d = np.concatenate([d1, d2])
                #
                bs = e.flt2bin(d, NBIT_FLTE, NBIT_FLTF)
                bs = e.reverse_bin(bs, nmerge)
                bs = e.merge_bin(bs, nmerge)
                bss.append(bs)
        else:
            ndim = [2, 2, M1, M1]
            for d, nd in zip(mapts, ndim):
                d = np.reshape(d, [-1, nd, 4])
                d1 = np.reshape(d[:, :, 0:2], [-1, nd * 2])
                d2 = np.reshape(d[:, :, 2:4], [-1, nd * 2])
                d = np.concatenate([d1, d2], axis=1)
                #
                bs = e.flt2bin(d, NBIT_FLTE, NBIT_FLTF)
                bss.append(bs)
        bswt, bdsw, bfea, bgra = bss
        return bswt, bdsw, bfea, bgra

    def wrap_lut(self):
        r"""Wrap the LUT.

        Returns
        -------
        bstd
            encoded ``dstd_inv`` table (first two entries of each type)
        bgtt
            encoded ``gt`` table, one row per pair of types
        bavc
            encoded product of the type embedding ``t_ebd`` and the type
            weight returned by ``get_type_weight(weight, 0)``
        """
        dscp = nvnmd_cfg.dscp
        maps = nvnmd_cfg.map
        nbit = nvnmd_cfg.nbit
        weight = nvnmd_cfg.weight

        M1 = dscp["M1"]
        ntype = dscp["ntype"]
        ntype_max = dscp["ntype_max"]
        NBIT_FLTE = nbit["NBIT_FLTE"]
        NBIT_FLTF = nbit["NBIT_FLTF"]
        NBIT_WXDB = nbit["NBIT_FIT_WXDB"]
        NBIT_DATA_FL = nbit["NBIT_FIT_DATA_FL"]

        e = Encode()
        # std; rows of unused type slots stay zero
        d = maps["dstd_inv"]
        d2 = np.zeros([ntype_max, 2])
        for ii in range(ntype):
            _d = d[ii, :2]
            _d = np.reshape(_d, [-1, 2])
            _d = np.concatenate([_d[:, 0], _d[:, 1]], axis=0)
            d2[ii] = _d
        bstd = e.flt2bin(d2, NBIT_FLTE, NBIT_FLTF)
        # gtt
        d = maps["gt"]
        d2 = np.zeros([ntype_max**2, M1])
        for ii in range(ntype):
            for jj in range(ntype):
                # the source table is read with a row stride of (ntype + 1),
                # the destination is written with a stride of ntype_max
                _d = d[ii * (ntype + 1) + jj]
                _d = np.reshape(_d, [-1, 2])
                # de-interleave: even-indexed entries first, then odd-indexed
                _d = np.concatenate([_d[:, 0], _d[:, 1]], axis=0)
                d2[ii * ntype_max + jj] = _d
        bgtt = e.flt2bin(d2, NBIT_FLTE, NBIT_FLTF)
        # avc
        d = maps["t_ebd"]
        w = get_type_weight(weight, 0)
        nd = w.shape[1]
        d2 = np.zeros([ntype_max, nd])
        for ii in range(ntype):
            _d = d[ii]
            _d = np.reshape(_d, [1, -1])
            _d = np.matmul(_d, w)
            d2[ii] = _d
        d2 = e.qr(d2, NBIT_DATA_FL)
        bavc = e.dec2bin(d2, NBIT_WXDB, True)
        return bstd, bgtt, bavc
def wrap(
    *,
    nvnmd_config: Optional[str] = "nvnmd/config.npy",
    nvnmd_weight: Optional[str] = "nvnmd/weight.npy",
    nvnmd_map: Optional[str] = "nvnmd/map.npy",
    nvnmd_model: Optional[str] = "nvnmd/model.pb",
    **kwargs,
):
    r"""Build a :class:`Wrap` from the given file names and run it.

    Parameters
    ----------
    nvnmd_config
        configuration file, passed to `Wrap` as ``config_file``
    nvnmd_weight
        weight file, passed to `Wrap` as ``weight_file``
    nvnmd_map
        mapping-table file, passed to `Wrap` as ``map_file``
    nvnmd_model
        output model file, passed to `Wrap` as ``model_file``
    **kwargs
        accepted and ignored
    """
    Wrap(
        config_file=nvnmd_config,
        weight_file=nvnmd_weight,
        map_file=nvnmd_map,
        model_file=nvnmd_model,
    ).wrap()