Program Listing for File commonTF.h
↰ Return to documentation for file (source/api_cc/include/commonTF.h)
// SPDX-License-Identifier: LGPL-3.0-or-later
#include <string>
#include <vector>
#ifdef TF_PRIVATE
#include "tf_private.h"
#else
#include "tf_public.h"
#endif
// TensorFlow-session helpers for the deepmd C++ API.
//
// NOTE(review): the declarations below use deepmd::AtomMap and InputNlist,
// but none of the headers included by this file visibly declare them. The
// file presumably relies on the includer having pulled in their definitions
// first. Confirm this, and consider including the defining headers here.
//
// NOTE(review): string parameters are taken as `const std::string` by value.
// `const std::string&` would avoid a copy per call, but that changes the
// function signatures. It must be changed together with the out-of-view
// definitions and explicit instantiations, or linking will fail.
namespace deepmd {
/**
 * @brief Check a TensorFlow status object.
 *
 * What happens on a non-OK status is decided by the implementation, which is
 * not visible here. Presumably it reports the error and throws or aborts;
 * confirm this.
 *
 * @param[in] status Status returned by a TensorFlow call.
 */
void check_status(const tensorflow::Status& status);

/**
 * @brief Read a single value of type @p VT named @p name from a session.
 *
 * @tparam VT Type of the value returned to the caller.
 * @param[in] session TensorFlow session to query.
 * @param[in] name    Name of the item to read.
 * @param[in] scope   Optional name scope; defaults to "". How it is combined
 *                    with @p name is up to the implementation (presumably as
 *                    a prefix; confirm).
 * @return The value read from the session.
 */
template <typename VT>
VT session_get_scalar(tensorflow::Session* session,
                      const std::string name,
                      const std::string scope = "");

/**
 * @brief Read a sequence of @p VT values named @p name_ from a session.
 *
 * @tparam VT Element type of the output vector.
 * @param[out] o_vec   Receives the values read. Whether it is cleared or
 *                     appended to is not visible here; confirm in the
 *                     implementation.
 * @param[in]  session TensorFlow session to query.
 * @param[in]  name_   Name of the item to read.
 * @param[in]  scope   Optional name scope; defaults to "".
 */
template <typename VT>
void session_get_vector(std::vector<VT>& o_vec,
                        tensorflow::Session* session,
                        const std::string name_,
                        const std::string scope = "");

/**
 * @brief Query the data type of the item named @p name in a session.
 *
 * @param[in] session TensorFlow session to query.
 * @param[in] name    Name of the item whose data type is requested.
 * @param[in] scope   Optional name scope; defaults to "".
 * @return An integer code for the data type. This is presumably a
 *         tensorflow::DataType enum value; confirm.
 */
int session_get_dtype(tensorflow::Session* session,
                      const std::string name,
                      const std::string scope = "");

/**
 * @brief Build the named input tensors for a model evaluation (overload
 *        taking a cell size instead of a neighbor list).
 *
 * @tparam MODELTYPE Type used for the model's tensors. It does not appear in
 *                   any parameter type; presumably it selects the precision
 *                   of the tensors that are built. Confirm.
 * @tparam VALUETYPE Element type of the caller-supplied coordinate, box and
 *                   parameter vectors.
 * @param[out] input_tensors Receives (name, tensor) pairs, presumably fed to
 *                           Session::Run. Clear vs. append is not visible.
 * @param[in] dcoord_   Atom coordinates. Presumably flattened as natoms x 3;
 *                      confirm.
 * @param[in] ntypes    Number of atom types.
 * @param[in] datype_   Atom types, presumably one entry per atom.
 * @param[in] dbox      Simulation box. Presumably 9 components, possibly
 *                      empty for non-periodic systems; confirm.
 * @param[in] cell_size Cell-size value consumed by the implementation. Its
 *                      exact role is not visible here.
 * @param[in] fparam_   Frame parameters.
 * @param[in] aparam_   Atomic parameters.
 * @param[in] atommap   Atom-order mapping applied to the per-atom inputs
 *                      (presumably; confirm).
 * @param[in] scope       Optional name scope; defaults to "".
 * @param[in] aparam_nall Presumably true when @p aparam_ covers all atoms
 *                        (local + ghost) rather than local atoms only;
 *                        defaults to false. Confirm.
 * @return An integer whose meaning is defined by the implementation
 *         (presumably the number of local atoms; confirm).
 */
template <typename MODELTYPE, typename VALUETYPE>
int session_input_tensors(
    std::vector<std::pair<std::string, tensorflow::Tensor>>& input_tensors,
    const std::vector<VALUETYPE>& dcoord_,
    const int& ntypes,
    const std::vector<int>& datype_,
    const std::vector<VALUETYPE>& dbox,
    const double& cell_size,
    const std::vector<VALUETYPE>& fparam_,
    const std::vector<VALUETYPE>& aparam_,
    const deepmd::AtomMap& atommap,
    const std::string scope = "",
    const bool aparam_nall = false);

/**
 * @brief Build the named input tensors for a model evaluation (overload
 *        taking an externally supplied neighbor list).
 *
 * @tparam MODELTYPE Type used for the model's tensors (see the other
 *                   overload).
 * @tparam VALUETYPE Element type of the caller-supplied vectors.
 * @param[out] input_tensors Receives (name, tensor) pairs for the session.
 * @param[in] dcoord_ Atom coordinates (presumably local + ghost atoms;
 *                    confirm).
 * @param[in] ntypes  Number of atom types.
 * @param[in] datype_ Atom types.
 * @param[in] dbox    Simulation box.
 * @param dlist       Neighbor list. It is taken by non-const reference, so
 *                    the implementation may modify it.
 * @param[in] fparam_  Frame parameters.
 * @param[in] aparam_  Atomic parameters.
 * @param[in] atommap  Atom-order mapping.
 * @param[in] nghost   Number of ghost atoms.
 * @param[in] ago      Presumably the number of steps since the neighbor list
 *                     was rebuilt (0 = freshly rebuilt); confirm.
 * @param[in] scope       Optional name scope; defaults to "".
 * @param[in] aparam_nall See the other overload; defaults to false.
 * @return An integer whose meaning is defined by the implementation
 *         (presumably the number of local atoms; confirm).
 */
template <typename MODELTYPE, typename VALUETYPE>
int session_input_tensors(
    std::vector<std::pair<std::string, tensorflow::Tensor>>& input_tensors,
    const std::vector<VALUETYPE>& dcoord_,
    const int& ntypes,
    const std::vector<int>& datype_,
    const std::vector<VALUETYPE>& dbox,
    InputNlist& dlist,
    const std::vector<VALUETYPE>& fparam_,
    const std::vector<VALUETYPE>& aparam_,
    const deepmd::AtomMap& atommap,
    const int nghost,
    const int ago,
    const std::string scope = "",
    const bool aparam_nall = false);

/**
 * @brief Build the named input tensors for a multi-frame, "mixed type"
 *        evaluation.
 *
 * The exact meaning of "mixed type" is not visible here. Presumably atom
 * types are supplied per frame rather than shared; confirm.
 *
 * @tparam MODELTYPE Type used for the model's tensors.
 * @tparam VALUETYPE Element type of the caller-supplied vectors.
 * @param[out] input_tensors Receives (name, tensor) pairs for the session.
 * @param[in] nframes   Number of frames in the inputs.
 * @param[in] dcoord_   Atom coordinates for all frames.
 * @param[in] ntypes    Number of atom types.
 * @param[in] datype_   Atom types (presumably per frame; confirm).
 * @param[in] dbox      Simulation box(es).
 * @param[in] cell_size Cell-size value consumed by the implementation.
 * @param[in] fparam_   Frame parameters.
 * @param[in] aparam_   Atomic parameters.
 * @param[in] atommap   Atom-order mapping.
 * @param[in] scope       Optional name scope; defaults to "".
 * @param[in] aparam_nall See session_input_tensors; defaults to false.
 * @return An integer whose meaning is defined by the implementation
 *         (presumably the number of local atoms; confirm).
 */
template <typename MODELTYPE, typename VALUETYPE>
int session_input_tensors_mixed_type(
    std::vector<std::pair<std::string, tensorflow::Tensor>>& input_tensors,
    const int& nframes,
    const std::vector<VALUETYPE>& dcoord_,
    const int& ntypes,
    const std::vector<int>& datype_,
    const std::vector<VALUETYPE>& dbox,
    const double& cell_size,
    const std::vector<VALUETYPE>& fparam_,
    const std::vector<VALUETYPE>& aparam_,
    const deepmd::AtomMap& atommap,
    const std::string scope = "",
    const bool aparam_nall = false);
}  // namespace deepmd