# Source code for deepmd.nvnmd.utils.encode

# SPDX-License-Identifier: LGPL-3.0-or-later
import logging

import numpy as np

from deepmd.nvnmd.data.data import (
    jdata_sys,
)

# Module-level logger; Encode.check_dec reports out-of-range values through it.
log = logging.getLogger(__name__)


class Encode:
    r"""Encoding value as hex, bin, and dec format.

    Helpers that quantize floating-point values and convert numbers to and
    from lists of binary / hexadecimal strings. All binary and hex strings
    are written most-significant digit first (as produced by ``bin``/``hex``).
    """

    def __init__(self):
        pass

    # ------------------------------------------------------------------
    # Quantization
    # ------------------------------------------------------------------
    def qr(self, v, nbit: int = 14):
        r"""Quantize value using round: ``round(v * 2**nbit)`` (as float)."""
        return np.round(v * (2**nbit))

    def qf(self, v, nbit: int = 14):
        r"""Quantize value using floor: ``floor(v * 2**nbit)`` (as float)."""
        return np.floor(v * (2**nbit))

    def qc(self, v, nbit: int = 14):
        r"""Quantize value using ceil: ``ceil(v * 2**nbit)`` (as float)."""
        return np.ceil(v * (2**nbit))

    # ------------------------------------------------------------------
    # Exponent handling
    # ------------------------------------------------------------------
    def split_expo_mant(self, v, min=-1000):
        r"""Split ``v`` into ``(expo, mant)`` such that ``v == mant * 2**expo``.

        Parameters
        ----------
        v
            Value or array of values.
        min
            Lower bound applied to the exponent, so ``v == 0`` gives
            ``expo == min`` and ``mant == 0`` instead of ``-inf``.

        Returns
        -------
        expo, mant
        """
        # NOTE(review): ``expo`` is not floored, so whenever |v| > 2**min the
        # mantissa is exactly sign(v) (+/-1). Confirm whether
        # ``np.floor(np.log2(vabs))`` was intended before relying on ``mant``.
        vabs = np.abs(v)
        expo = np.maximum(np.log2(vabs), min)
        mant = v * (1.0 / 2.0**expo)
        return expo, mant

    def find_max_expo(self, v, expo_min=-1000):
        r"""Return ``floor(log2(max(|v|)))``, clamped from below at ``expo_min``.

        A ``1e-50`` offset keeps ``log2`` finite when ``v`` is all zeros.
        ``v`` must be non-empty (``np.max`` of an empty array raises).
        """
        vmax = np.max(np.abs(v))
        expo_max = np.maximum(np.log2(vmax + 1e-50), expo_min)
        return np.floor(expo_max)

    def norm_expo(self, v, nbit_frac=20, expo_min=-1000):
        r"""Normalize ``v`` by its largest exponent and truncate the fraction.

        Returns
        -------
        (v_norm, expo_max)
            ``expo_max = find_max_expo(v, expo_min)`` and
            ``v_norm = sign(v) * floor(|v| * 2**(nbit_frac - expo_max)) / 2**nbit_frac``,
            i.e. ``v / 2**expo_max`` truncated toward zero to ``nbit_frac``
            fractional bits, so that ``v ~= v_norm * 2**expo_max``.
        """
        expo_max = self.find_max_expo(v, expo_min)
        prec_expo = 2 ** (nbit_frac - expo_max)
        prec = 2**nbit_frac
        vabs = np.floor(np.abs(v) * prec_expo) / prec
        return np.sign(v) * vabs, expo_max

    # ------------------------------------------------------------------
    # Float <-> binary
    # ------------------------------------------------------------------
    def flt2bin_one(self, v, nbit_expo, nbit_frac):
        r"""Encode one float as a ``sign + exponent + fraction`` bit string.

        The result has 1 sign bit, ``nbit_expo`` exponent bits (the binary
        exponent plus an offset of ``2**(nbit_expo - 1)``) and the leading
        ``nbit_frac`` bits of the float64 fraction (truncated). Zero and
        float64 subnormals (hex mantissa starting with ``0``) encode as an
        all-zero string of length ``1 + nbit_expo + nbit_frac``.

        Raises
        ------
        ValueError
            If ``v`` is ``inf`` or ``nan``, which have no hex mantissa.
        """
        v = float(v)  # 64-bit float
        if not np.isfinite(v):
            raise ValueError(f"NVNMD: cannot encode non-finite value {v}")
        h = v.hex()  # e.g. '-0x1.8000000000000p+3'
        st = h.index("x") + 1  # first mantissa digit ('1' normal, '0' zero/subnormal)
        ed = h.index("p") + 1  # first character of the exponent
        if h[st] == "0":
            return "0" * (1 + nbit_expo + nbit_frac)
        s = "1" if h[0] == "-" else "0"
        # Adding 2**nbit_expo forces a leading '1' so bin()[3:] (dropping
        # '0b1') yields exactly nbit_expo bits.
        # NOTE(review): exponents outside [-2**(nbit_expo-1), 2**(nbit_expo-1)-1]
        # are not range-checked and produce a string of the wrong width.
        e = int(h[ed:]) + int(2 ** (nbit_expo - 1) + 2**nbit_expo)
        e = bin(e)[3:]
        # Hex digits after '1.' up to (not including) 'p'. A float64 has at
        # most 52 fraction bits, so nbit_frac > 52 yields a shorter fraction.
        fh = h[st + 2 : ed - 1]
        f = self.hex2bin_str(fh)[:nbit_frac]
        return s + e + f

    def flt2bin(self, data, nbit_expo, nbit_frac):
        r"""Convert float into binary string list (one ``flt2bin_one`` per element)."""
        data = np.reshape(np.array(data), [-1])
        return [self.flt2bin_one(d, nbit_expo, nbit_frac) for d in data]

    # ------------------------------------------------------------------
    # Bytes -> hex
    # ------------------------------------------------------------------
    def byte2hex(self, bs, nbyte):
        r"""Convert bytes into hex strings of ``nbyte`` bytes each.

        In ``bs`` the low byte comes first; in each returned hex string the
        low byte is on the right. Trailing bytes that do not fill a whole
        group of ``nbyte`` are dropped.
        """
        nl = len(bs) // nbyte
        return [bs[nbyte * ii : nbyte * (ii + 1)][::-1].hex() for ii in range(nl)]

    # ------------------------------------------------------------------
    # Integer range checks / int -> binary
    # ------------------------------------------------------------------
    @staticmethod
    def _dec_range(nbit, signed):
        """Return the inclusive ``(pmin, pmax)`` range shared by check_dec and dec2bin."""
        prec = np.int64(2**nbit)
        if signed:
            # Symmetric range: -2**(nbit-1) is deliberately excluded.
            pmax = prec // 2 - 1
            pmin = -pmax
        else:
            pmax = prec - 1
            pmin = 0
        return pmin, pmax

    def check_dec(self, idec, nbit, signed=False, name=""):
        r"""Check whether the data (idec) is in the range of an ``nbit`` integer.

        The range is :math:`[0, 2^{nbit}-1]` for unsigned data and the
        symmetric range :math:`[-(2^{nbit-1}-1), 2^{nbit-1}-1]` for signed
        data (the most negative two's-complement value is excluded).

        Nothing is modified or raised: when ``jdata_sys["debug"]`` is truthy,
        a warning is logged if any value is below or above the range.
        """
        pmin, pmax = self._dec_range(nbit, signed)
        below = idec < pmin
        above = idec > pmax
        if jdata_sys["debug"]:
            if np.any(below):
                log.warning(
                    f"NVNMD: there are data {name} smaller than the lower limit {pmin}"
                )
            if np.any(above):
                log.warning(
                    f"NVNMD: there are data {name} bigger than the upper limit {pmax}"
                )

    def dec2bin(self, idec, nbit=10, signed=False, name=""):
        r"""Convert dec array to a list of ``nbit``-character binary strings.

        Values are saturated to the range described in ``check_dec``; signed
        values are written in two's complement.
        """
        idec = np.int64(np.reshape(np.array(idec), [-1]))
        self.check_dec(idec, nbit, signed, name)
        pmin, pmax = self._dec_range(nbit, signed)
        idec = np.minimum(pmax, np.maximum(pmin, idec))
        # Adding 2 * 2**nbit makes every value positive with at least nbit+1
        # bits while keeping the low nbit bits equal to the two's complement.
        idec = idec + 2 * np.int64(2**nbit)
        return [bin(x)[-nbit:] for x in idec]

    # ------------------------------------------------------------------
    # List / string padding and reshaping
    # ------------------------------------------------------------------
    def extend_list(self, slbin, nfull):
        r"""Extend the list (slbin) to the length (nfull) by appending zero strings.

        Each appended element is ``'0'`` repeated to the width of
        ``slbin[0]``. For example, ``slbin = ['10010', '10100']`` with
        ``nfull = 4`` becomes ``['10010', '10100', '00000', '00000']``.
        If ``slbin`` already has at least ``nfull`` elements, a copy is
        returned unchanged (so an empty list stays empty for ``nfull <= 0``).
        Padding an empty list to a positive length raises ``IndexError``
        because the element width is unknown.
        """
        nfull = int(nfull)
        dn = nfull - len(slbin)
        if dn <= 0:
            return list(slbin)
        ds = "0" * len(slbin[0])
        return list(slbin) + [ds] * dn

    def extend_bin(self, slbin, nfull):
        r"""Left-pad every element of the list (slbin) with ``'0'`` to length (nfull).

        The pad width is computed from ``slbin[0]`` and applied to every
        element. For example, ``slbin = ['10010', '10100']`` with
        ``nfull = 6`` becomes ``['010010', '010100']``. An empty list
        returns ``[]``.
        """
        nfull = int(nfull)
        if len(slbin) == 0:
            return []
        ds = "0" * (nfull - len(slbin[0]))
        return [ds + s for s in slbin]

    def extend_hex(self, slhex, nfull):
        r"""Left-pad every element of the list (slhex) to ``nfull // 4`` hex digits.

        ``nfull`` is a bit count; the pad width is computed from ``slhex[0]``.
        An empty list returns ``[]``.
        """
        # NOTE(review): floor division means a non-multiple-of-4 ``nfull``
        # under-pads (e.g. 10 bits -> 2 digits); confirm callers pass multiples of 4.
        nfull = int(nfull)
        if len(slhex) == 0:
            return []
        ds = "0" * ((nfull // 4) - len(slhex[0]))
        return [ds + s for s in slhex]

    def split_bin(self, sbin, nbit: int):
        r"""Split sbin into segments of ``nbit`` characters, lowest-order first.

        The string is left-padded with ``'0'`` to a multiple of ``nbit``; the
        rightmost segment becomes element 0. A list input is split element by
        element and the results are concatenated in order.
        """
        if isinstance(sbin, list):
            return [seg for s in sbin for seg in self.split_bin(s, nbit)]
        n = len(sbin)
        nseg = int(np.ceil(n / nbit))
        sbin = "0" * int(nseg * nbit - n) + sbin
        return [sbin[ii * nbit : (ii + 1) * nbit] for ii in reversed(range(nseg))]

    def reverse_bin(self, slbin, nreverse):
        r"""Reverse the order of the binary string list within each group of ``nreverse``.

        The list is first zero-padded (see ``extend_list``) to a multiple of
        ``nreverse``.
        """
        nreverse = int(nreverse)
        n = int(np.ceil(len(slbin) / nreverse))
        slbin = self.extend_list(slbin, n * nreverse)
        return [
            slbin[ii * nreverse + nreverse - 1 - jj]
            for ii in range(n)
            for jj in range(nreverse)
        ]

    def merge_bin(self, slbin, nmerge):
        r"""Concatenate every ``nmerge`` consecutive strings of the list into one.

        The list is first zero-padded (see ``extend_list``) to a multiple of
        ``nmerge``.
        """
        nmerge = int(nmerge)
        n = int(np.ceil(len(slbin) / nmerge))
        slbin = self.extend_list(slbin, n * nmerge)
        return ["".join(slbin[nmerge * ii : nmerge * (ii + 1)]) for ii in range(n)]

    # ------------------------------------------------------------------
    # Hex <-> binary strings
    # ------------------------------------------------------------------
    def hex2bin_str(self, shex):
        r"""Convert hex string to binary string (4 bits per hex digit)."""
        return "".join(format(int(c, 16), "04b") for c in shex)

    def hex2bin(self, data):
        r"""Convert hex string list to binary string list."""
        data = np.reshape(np.array(data), [-1])
        return [self.hex2bin_str(d) for d in data]

    def bin2hex_str(self, sbin):
        r"""Convert binary string to lowercase hex string.

        The input is left-padded with ``'0'`` to a multiple of 4 bits.
        """
        n = len(sbin)
        nx = int(np.ceil(n / 4))
        sbin = "0" * (nx * 4 - n) + sbin
        return "".join(format(int(sbin[4 * ii : 4 * (ii + 1)], 2), "x") for ii in range(nx))

    def bin2hex(self, data):
        r"""Convert binary string list to hex string list."""
        data = np.reshape(np.array(data), [-1])
        return [self.bin2hex_str(d) for d in data]