Program Listing for File DataModifierTF.h
↰ Return to documentation for file (source/api_cc/include/DataModifierTF.h)
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once
#include <cassert>
#include <string>
#include <utility>
#include <vector>

#include "DataModifier.h"
#include "common.h"
#include "commonTF.h"
namespace deepmd {
/**
 * @brief TensorFlow-backed implementation of DipoleChargeModifierBase.
 *
 * Wraps a tensorflow::Session and tensorflow::GraphDef together with model
 * metadata (cutoff radius, number of atom types, selected types) that the
 * public accessors return. The accessors assert that the object has been
 * initialized before reading that metadata.
 *
 * The object stores raw tensorflow::Session / tensorflow::GraphDef pointers,
 * so copying and moving are disabled: an implicit member-wise copy would
 * leave two objects sharing the same TensorFlow resources.
 */
class DipoleChargeModifierTF : public DipoleChargeModifierBase {
 public:
  /// @brief Construct an uninitialized modifier; call init() before use.
  DipoleChargeModifierTF();

  /**
   * @brief Construct from a model file.
   *
   * Presumably equivalent to default construction followed by init() with the
   * same arguments (defined out of line — confirm).
   *
   * @param model      Path to the model file.
   * @param gpu_rank   GPU rank to run on (default 0).
   * @param name_scope Name scope of the model's graph nodes (default "").
   */
  DipoleChargeModifierTF(const std::string& model,
                         const int& gpu_rank = 0,
                         const std::string& name_scope = "");

  ~DipoleChargeModifierTF();

  // Non-copyable and non-movable. The raw `session` / `graph_def` pointers
  // below are assumed to be managed by this object (released by the
  // out-of-line destructor — TODO confirm); a shallow copy would alias them.
  DipoleChargeModifierTF(const DipoleChargeModifierTF&) = delete;
  DipoleChargeModifierTF& operator=(const DipoleChargeModifierTF&) = delete;
  DipoleChargeModifierTF(DipoleChargeModifierTF&&) = delete;
  DipoleChargeModifierTF& operator=(DipoleChargeModifierTF&&) = delete;

  /**
   * @brief Initialize the modifier from a model file.
   * @param model      Path to the model file.
   * @param gpu_rank   GPU rank to run on (default 0).
   * @param name_scope Name scope of the model's graph nodes (default "").
   */
  void init(const std::string& model,
            const int& gpu_rank = 0,
            const std::string& name_scope = "");

 private:
  /**
   * @brief Precision-generic implementation shared by the computew overloads
   *        (presumably; the dispatch is defined out of line — confirm).
   *
   * Output vectors are the leading `dfcorr_` / `dvcorr_` arguments
   * (presumably force and virial corrections). Remaining arguments are
   * read-only inputs: coordinates, atom types, box, atom pairs, external
   * field, ghost-atom count, and neighbor list.
   */
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& dfcorr_,
               std::vector<VALUETYPE>& dvcorr_,
               const std::vector<VALUETYPE>& dcoord_,
               const std::vector<int>& datype_,
               const std::vector<VALUETYPE>& dbox,
               const std::vector<std::pair<int, int>>& pairs,
               const std::vector<VALUETYPE>& delef_,
               const int nghost,
               const InputNlist& lmp_list);

 public:
  /// @brief Cutoff radius of the model. Requires init() to have run.
  double cutoff() const {
    assert(inited);
    return rcut;
  }

  /// @brief Number of atom types of the model. Requires init() to have run.
  int numb_types() const {
    assert(inited);
    return ntypes;
  }

  /// @brief Atom types selected by the model. Requires init() to have run.
  const std::vector<int>& sel_types() const {
    assert(inited);
    return sel_type;
  }

  /// @brief Double-precision entry point; see compute() for argument roles.
  void computew(std::vector<double>& dfcorr_,
                std::vector<double>& dvcorr_,
                const std::vector<double>& dcoord_,
                const std::vector<int>& datype_,
                const std::vector<double>& dbox,
                const std::vector<std::pair<int, int>>& pairs,
                const std::vector<double>& delef_,
                const int nghost,
                const InputNlist& lmp_list);

  /// @brief Single-precision entry point; see compute() for argument roles.
  void computew(std::vector<float>& dfcorr_,
                std::vector<float>& dvcorr_,
                const std::vector<float>& dcoord_,
                const std::vector<int>& datype_,
                const std::vector<float>& dbox,
                const std::vector<std::pair<int, int>>& pairs,
                const std::vector<float>& delef_,
                const int nghost,
                const InputNlist& lmp_list);

 private:
  // Default member initializers guarantee a defined state even if a
  // constructor's mem-initializer list omits a member; an explicit
  // mem-initializer still takes precedence. Declaration order is unchanged
  // so constructor initialization order is unaffected.
  tensorflow::Session* session = nullptr;
  std::string name_scope, name_prefix;
  int num_intra_nthreads = 0;
  int num_inter_nthreads = 0;
  tensorflow::GraphDef* graph_def = nullptr;
  bool inited = false;  // checked by the accessors above
  double rcut = 0.0;
  int dtype = 0;
  double cell_size = 0.0;
  int ntypes = 0;
  std::string model_type;
  std::vector<int> sel_type;

  /// @brief Read a scalar named `name` (presumably from the loaded graph).
  template <class VT>
  VT get_scalar(const std::string& name) const;

  /// @brief Fill `vec` with the vector named `name` (presumably from the graph).
  template <class VT>
  void get_vector(std::vector<VT>& vec, const std::string& name) const;

  /**
   * @brief Run `session` on `input_tensors`, writing results to `dforce` and
   *        `dvirial`.
   *
   * `atommap` and `nghost` are presumably used to map outputs back to the
   * caller's atom ordering — confirm against the out-of-line definition.
   */
  template <typename MODELTYPE, typename VALUETYPE>
  void run_model(std::vector<VALUETYPE>& dforce,
                 std::vector<VALUETYPE>& dvirial,
                 tensorflow::Session* session,
                 const std::vector<std::pair<std::string, tensorflow::Tensor>>&
                     input_tensors,
                 const AtomMap& atommap,
                 const int nghost);
};
}  // namespace deepmd