Program Listing for File env_mat_nvnmd.h

Return to documentation for file (source/lib/include/env_mat_nvnmd.h)

// SPDX-License-Identifier: LGPL-3.0-or-later
/*
//==================================================
 _   _  __     __  _   _   __  __   ____
| \ | | \ \   / / | \ | | |  \/  | |  _ \
|  \| |  \ \ / /  |  \| | | |\/| | | | | |
| |\  |   \ V /   | |\  | | |  | | | |_| |
|_| \_|    \_/    |_| \_| |_|  |_| |____/

//==================================================

code: nvnmd
reference: deepmd
author: mph (pinghui_mo@outlook.com)
date: 2021-12-6

*/

#pragma once

#include <cmath>
#include <vector>

#include "env_mat_nvnmd.h"
#include "utilities.h"

namespace deepmd {

template <typename FPTYPE>
void env_mat_a_nvnmd_quantize_cpu(std::vector<FPTYPE> &descrpt_a,
                                  std::vector<FPTYPE> &descrpt_a_deriv,
                                  std::vector<FPTYPE> &rij_a,
                                  const std::vector<FPTYPE> &posi,
                                  const std::vector<int> &type,
                                  const int &i_idx,
                                  const std::vector<int> &fmt_nlist,
                                  const std::vector<int> &sec,
                                  const float &rmin,
                                  const float &rmax);
}

union U_Flt64_Int64 {
  double nflt;
  int64_t nint;
};

/* 21-bit fraction */
// #define NBIT_FLTF 21
// #define NBIT_CUTF (52 - NBIT_FLTF)
// #define FLT_MASK 0xffffffff80000000

/* 20-bit fraction */
#define NBIT_FLTF 20
#define NBIT_CUTF (52 - NBIT_FLTF)
#define FLT_MASK 0xffffffff00000000

/*
  split double into sign, expo, and frac
*/
template <class T>  // float and double
void split_flt(T x, int64_t &sign, int64_t &expo, int64_t &mant) {
  U_Flt64_Int64 ufi;
  ufi.nflt = x;
  sign = (ufi.nint >> 63) & 0x01;
  expo = ((ufi.nint >> 52) & 0x7ff) - 1023;
  mant = (ufi.nint & 0xfffffffffffff) | 0x10000000000000;  // 1+52
}

/*
 find the max exponent for float array x
*/
template <class T>  // float and double
void find_max_expo(int64_t &max_expo, T *x, int64_t M) {
  int ii, jj, kk;
  U_Flt64_Int64 ufi;
  int64_t expo;
  max_expo = -100;
  for (jj = 0; jj < M; jj++) {
    ufi.nflt = x[jj];
    expo = ((ufi.nint >> 52) & 0x7ff) - 1023;
    max_expo = (expo > max_expo) ? expo : max_expo;
  }
};

/*
 find the max exponent for float array x
*/
template <class T>  // float and double
void find_max_expo(int64_t &max_expo, T *x, int64_t N, int64_t M) {
  int ii, jj, kk;
  U_Flt64_Int64 ufi;
  int64_t expo;
  max_expo = -100;
  for (ii = 0; ii < N; ii++) {
    ufi.nflt = x[ii * M];
    expo = ((ufi.nint >> 52) & 0x7ff) - 1023;
    max_expo = (expo > max_expo) ? expo : max_expo;
  }
};

/*
 dot multiply
*/
template <class T>  // float and double
void dotmul_flt_nvnmd(T &y, T *x1, T *x2, int64_t M) {
  int ii, jj, kk;
  U_Flt64_Int64 ufi;
  //
  int64_t sign1, sign2, sign3;
  int64_t expo1, expo2, expo3;
  int64_t mant1, mant2, mant3;
  int64_t expos;
  //
  int64_t expo_max1, expo_max2;
  //
  find_max_expo(expo_max1, x1, M);
  find_max_expo(expo_max2, x2, M);
  //
  int64_t s = 0;
  for (jj = 0; jj < M; jj++) {
    // x1
    split_flt(x1[jj], sign1, expo1, mant1);
    mant1 >>= NBIT_CUTF;
    expos = expo_max1 - expo1;
    expos = (expos > 63) ? 63 : expos;
    mant1 >>= expos;
    // x2
    split_flt(x2[jj], sign2, expo2, mant2);
    mant2 >>= NBIT_CUTF;
    expos = expo_max2 - expo2;
    expos = (expos > 63) ? 63 : expos;
    mant2 >>= expos;
    // multiply
    mant3 = mant1 * mant2;
    mant3 = (sign1 ^ sign2) ? -mant3 : mant3;
    s += mant3;
  }
  // y * 2^(e_a+e_b)
  ufi.nflt = T(s) * pow(2.0, expo_max1 + expo_max2 - NBIT_FLTF - NBIT_FLTF);
  ufi.nint &= FLT_MASK;
  y = ufi.nflt;
}

/*
  multiply
*/
template <class T>  // float and double
void mul_flt_nvnmd(T &y, T x1, T x2) {
  U_Flt64_Int64 ufi1, ufi2, ufi3;
  ufi1.nflt = x1;
  ufi1.nint &= FLT_MASK;
  ufi2.nflt = x2;
  ufi2.nint &= FLT_MASK;
  ufi3.nflt = ufi2.nflt * ufi1.nflt;
  ufi3.nint &= FLT_MASK;
  y = ufi3.nflt;
}

/*
  add
*/
template <class T>  // float and double
void add_flt_nvnmd(T &y, T x1, T x2) {
  U_Flt64_Int64 ufi1, ufi2, ufi3;
  int64_t sign1, sign2, sign3;
  int64_t expo1, expo2, expo3;
  int64_t mant1, mant2, mant3;
  int64_t expos;
  // convert data
  split_flt(x1, sign1, expo1, mant1);
  mant1 >>= NBIT_CUTF;

  split_flt(x2, sign2, expo2, mant2);
  mant2 >>= NBIT_CUTF;

  // shift
  if (expo1 >= expo2) {
    expo3 = expo1;
    expos = expo1 - expo2;
    expos = (expos > 63) ? 63 : expos;
    mant2 >>= expos;
  } else {
    expo3 = expo2;
    expos = expo2 - expo1;
    expos = (expos > 63) ? 63 : expos;
    mant1 >>= expos;
  }

  // add
  mant1 = (sign1) ? -mant1 : mant1;
  mant2 = (sign2) ? -mant2 : mant2;
  mant3 = mant1 + mant2;
  // fix2flt
  ufi3.nflt = double(mant3) * pow(2.0, expo3 - NBIT_FLTF);
  ufi3.nint &= FLT_MASK;
  y = ufi3.nflt;
}