Program Listing for File DeepTensor.h
↰ Return to documentation for file (source/api_cc/include/DeepTensor.h
)
#pragma once
#include "common.h"
#include "neighbor_list.h"
namespace deepmd{
class DeepTensor
{
public:
DeepTensor();
DeepTensor(const std::string & model,
const int & gpu_rank = 0,
const std::string &name_scope = "");
void init (const std::string & model,
const int & gpu_rank = 0,
const std::string &name_scope = "");
void print_summary(const std::string &pre) const;
public:
void compute (std::vector<VALUETYPE> & value,
const std::vector<VALUETYPE> & coord,
const std::vector<int> & atype,
const std::vector<VALUETYPE> & box);
void compute (std::vector<VALUETYPE> & value,
const std::vector<VALUETYPE> & coord,
const std::vector<int> & atype,
const std::vector<VALUETYPE> & box,
const int nghost,
const InputNlist & inlist);
void compute (std::vector<VALUETYPE> & global_tensor,
std::vector<VALUETYPE> & force,
std::vector<VALUETYPE> & virial,
const std::vector<VALUETYPE> & coord,
const std::vector<int> & atype,
const std::vector<VALUETYPE> & box);
void compute (std::vector<VALUETYPE> & global_tensor,
std::vector<VALUETYPE> & force,
std::vector<VALUETYPE> & virial,
const std::vector<VALUETYPE> & coord,
const std::vector<int> & atype,
const std::vector<VALUETYPE> & box,
const int nghost,
const InputNlist & inlist);
void compute (std::vector<VALUETYPE> & global_tensor,
std::vector<VALUETYPE> & force,
std::vector<VALUETYPE> & virial,
std::vector<VALUETYPE> & atom_tensor,
std::vector<VALUETYPE> & atom_virial,
const std::vector<VALUETYPE> & coord,
const std::vector<int> & atype,
const std::vector<VALUETYPE> & box);
void compute (std::vector<VALUETYPE> & global_tensor,
std::vector<VALUETYPE> & force,
std::vector<VALUETYPE> & virial,
std::vector<VALUETYPE> & atom_tensor,
std::vector<VALUETYPE> & atom_virial,
const std::vector<VALUETYPE> & coord,
const std::vector<int> & atype,
const std::vector<VALUETYPE> & box,
const int nghost,
const InputNlist & inlist);
VALUETYPE cutoff () const {assert(inited); return rcut;};
int numb_types () const {assert(inited); return ntypes;};
int output_dim () const {assert(inited); return odim;};
const std::vector<int> & sel_types () const {assert(inited); return sel_type;};
private:
tensorflow::Session* session;
std::string name_scope;
int num_intra_nthreads, num_inter_nthreads;
tensorflow::GraphDef graph_def;
bool inited;
VALUETYPE rcut;
VALUETYPE cell_size;
int ntypes;
std::string model_type;
std::string model_version;
int odim;
std::vector<int> sel_type;
template<class VT> VT get_scalar(const std::string & name) const;
template<class VT> void get_vector (std::vector<VT> & vec, const std::string & name) const;
void run_model (std::vector<VALUETYPE> & d_tensor_,
tensorflow::Session * session,
const std::vector<std::pair<std::string, tensorflow::Tensor>> & input_tensors,
const AtomMap<VALUETYPE> & atommap,
const std::vector<int> & sel_fwd,
const int nghost = 0);
void run_model (std::vector<VALUETYPE> & dglobal_tensor_,
std::vector<VALUETYPE> & dforce_,
std::vector<VALUETYPE> & dvirial_,
std::vector<VALUETYPE> & datom_tensor_,
std::vector<VALUETYPE> & datom_virial_,
tensorflow::Session * session,
const std::vector<std::pair<std::string, tensorflow::Tensor>> & input_tensors,
const AtomMap<VALUETYPE> & atommap,
const std::vector<int> & sel_fwd,
const int nghost = 0);
void compute_inner (std::vector<VALUETYPE> & value,
const std::vector<VALUETYPE> & coord,
const std::vector<int> & atype,
const std::vector<VALUETYPE> & box);
void compute_inner (std::vector<VALUETYPE> & value,
const std::vector<VALUETYPE> & coord,
const std::vector<int> & atype,
const std::vector<VALUETYPE> & box,
const int nghost,
const InputNlist& inlist);
void compute_inner (std::vector<VALUETYPE> & global_tensor,
std::vector<VALUETYPE> & force,
std::vector<VALUETYPE> & virial,
std::vector<VALUETYPE> & atom_tensor,
std::vector<VALUETYPE> & atom_virial,
const std::vector<VALUETYPE> & coord,
const std::vector<int> & atype,
const std::vector<VALUETYPE> & box);
void compute_inner (std::vector<VALUETYPE> & global_tensor,
std::vector<VALUETYPE> & force,
std::vector<VALUETYPE> & virial,
std::vector<VALUETYPE> & atom_tensor,
std::vector<VALUETYPE> & atom_virial,
const std::vector<VALUETYPE> & coord,
const std::vector<int> & atype,
const std::vector<VALUETYPE> & box,
const int nghost,
const InputNlist& inlist);
};
}