Program Listing for File DeepTensorTF.h

Return to documentation for file (source/api_cc/include/DeepTensorTF.h)

// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include <cassert>
#include <string>
#include <utility>
#include <vector>

#include "DeepTensor.h"
#include "common.h"
#include "neighbor_list.h"

namespace deepmd {
/**
 * @brief TensorFlow-backed implementation of DeepTensorBase.
 *
 * Holds a raw `tensorflow::Session*` and `tensorflow::GraphDef*` together with
 * model metadata (cutoff, number of types, output dimension, selected types)
 * that init() is expected to populate. The metadata accessors only check
 * initialization via `assert(inited)`, so in NDEBUG builds they return the
 * in-class default values below if init() was never called.
 */
class DeepTensorTF : public DeepTensorBase {
 public:
  /// @brief Construct an instance without loading a model; presumably init()
  ///        must be called before use — confirm in the implementation.
  DeepTensorTF();
  /// @brief Destructor (defined out of line; which of `session` / `graph_def`
  ///        it releases is not visible here — TODO confirm).
  ~DeepTensorTF();

  // The class stores raw session/graph pointers and has a user-declared
  // destructor. A compiler-generated copy would make two objects share the
  // same pointers (and possibly free them twice), so copying and moving are
  // disabled. Hold instances through a (smart) pointer instead.
  DeepTensorTF(const DeepTensorTF&) = delete;
  DeepTensorTF& operator=(const DeepTensorTF&) = delete;
  DeepTensorTF(DeepTensorTF&&) = delete;
  DeepTensorTF& operator=(DeepTensorTF&&) = delete;

  /**
   * @brief Construct and load a model (presumably equivalent to default
   *        construction followed by init() — confirm).
   * @param[in] model Model identifier; presumably a path to a frozen graph
   *            file — confirm.
   * @param[in] gpu_rank GPU rank to use (default 0).
   * @param[in] name_scope Name-scope prefix for graph nodes (default empty).
   *
   * `explicit` because the trailing defaults make this callable with a single
   * std::string, which would otherwise be an implicit conversion.
   */
  explicit DeepTensorTF(const std::string& model,
                        const int& gpu_rank = 0,
                        const std::string& name_scope = "");
  /**
   * @brief Load a model into this instance.
   * @param[in] model Model identifier; presumably a path to a frozen graph
   *            file — confirm.
   * @param[in] gpu_rank GPU rank to use (default 0).
   * @param[in] name_scope Name-scope prefix for graph nodes (default empty).
   */
  void init(const std::string& model,
            const int& gpu_rank = 0,
            const std::string& name_scope = "");

 private:
  // Typed evaluation overloads; presumably the computew() entry points
  // dispatch to these — confirm in the implementation.

  /// @brief Evaluate the model and write the result into @p value.
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& value,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box);
  /// @brief As above, with @p nghost ghost atoms and a caller-supplied
  ///        neighbor list @p inlist.
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& value,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& inlist);
  /// @brief Evaluate the model, producing global/atomic tensors together with
  ///        the force and virial outputs.
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& global_tensor,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               std::vector<VALUETYPE>& atom_tensor,
               std::vector<VALUETYPE>& atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box);
  /// @brief As above, with @p nghost ghost atoms and a caller-supplied
  ///        neighbor list @p inlist.
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& global_tensor,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               std::vector<VALUETYPE>& atom_tensor,
               std::vector<VALUETYPE>& atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& inlist);

 public:
  /// @brief Return the cached cutoff (`rcut`); asserts the object is inited.
  double cutoff() const {
    assert(inited);
    return rcut;
  }
  /// @brief Return the cached number of types (`ntypes`); asserts inited.
  int numb_types() const {
    assert(inited);
    return ntypes;
  }
  /// @brief Return the cached output dimension (`odim`); asserts inited.
  int output_dim() const {
    assert(inited);
    return odim;
  }
  /// @brief Return the cached selected types (`sel_type`); asserts inited.
  ///        The reference stays valid for the lifetime of this object.
  const std::vector<int>& sel_types() const {
    assert(inited);
    return sel_type;
  }
  /// @brief Write the model's type map into @p type_map (output parameter;
  ///        exact format defined by the implementation — TODO confirm).
  void get_type_map(std::string& type_map);

  /**
   * @brief Non-template evaluation entry points (double and float variants,
   *        with or without an external neighbor list).
   *
   * @param[out] global_tensor Global output tensor.
   * @param[out] force Force-like derivative output.
   * @param[out] virial Virial-like derivative output.
   * @param[out] atom_tensor Per-atom output tensor.
   * @param[out] atom_virial Per-atom virial-like output.
   * @param[in] coord Coordinates (layout defined by the implementation).
   * @param[in] atype Per-atom type indices.
   * @param[in] box Box/cell data (layout defined by the implementation).
   * @param[in] nghost Number of ghost atoms (neighbor-list overloads only).
   * @param[in] inlist Caller-supplied neighbor list (neighbor-list overloads
   *            only).
   * @param[in] request_deriv Presumably selects whether the derivative
   *            outputs (force/virial/atom_virial) are computed — confirm.
   */
  void computew(std::vector<double>& global_tensor,
                std::vector<double>& force,
                std::vector<double>& virial,
                std::vector<double>& atom_tensor,
                std::vector<double>& atom_virial,
                const std::vector<double>& coord,
                const std::vector<int>& atype,
                const std::vector<double>& box,
                const bool request_deriv);
  void computew(std::vector<float>& global_tensor,
                std::vector<float>& force,
                std::vector<float>& virial,
                std::vector<float>& atom_tensor,
                std::vector<float>& atom_virial,
                const std::vector<float>& coord,
                const std::vector<int>& atype,
                const std::vector<float>& box,
                const bool request_deriv);
  void computew(std::vector<double>& global_tensor,
                std::vector<double>& force,
                std::vector<double>& virial,
                std::vector<double>& atom_tensor,
                std::vector<double>& atom_virial,
                const std::vector<double>& coord,
                const std::vector<int>& atype,
                const std::vector<double>& box,
                const int nghost,
                const InputNlist& inlist,
                const bool request_deriv);
  void computew(std::vector<float>& global_tensor,
                std::vector<float>& force,
                std::vector<float>& virial,
                std::vector<float>& atom_tensor,
                std::vector<float>& atom_virial,
                const std::vector<float>& coord,
                const std::vector<int>& atype,
                const std::vector<float>& box,
                const int nghost,
                const InputNlist& inlist,
                const bool request_deriv);

 private:
  // Declaration order is unchanged from the original; in-class initializers
  // give every scalar/pointer a defined value even if a constructor does not
  // set it (previously these were indeterminate).

  /// TensorFlow session (raw pointer; ownership handled in the .cc file).
  tensorflow::Session* session = nullptr;
  /// Name-scope prefix, presumably stored from the constructor/init() arg.
  std::string name_scope;
  /// Presumably TF intra-/inter-op thread counts — confirm.
  int num_intra_nthreads = 0;
  int num_inter_nthreads = 0;
  /// Graph definition (raw pointer; ownership handled in the .cc file).
  tensorflow::GraphDef* graph_def = nullptr;
  /// Checked by the metadata accessors via assert().
  bool inited = false;
  /// Value returned by cutoff().
  double rcut = 0.0;
  /// NOTE(review): meaning not visible here; presumably the model's
  /// floating-point precision code — confirm.
  int dtype = 0;
  /// NOTE(review): meaning not visible here — confirm in the implementation.
  double cell_size = 0.0;
  /// Value returned by numb_types().
  int ntypes = 0;
  std::string model_type;
  std::string model_version;
  /// Value returned by output_dim().
  int odim = 0;
  /// Value returned by sel_types().
  std::vector<int> sel_type;

  /// @brief Read a scalar named @p name (presumably from the loaded graph).
  template <class VT>
  VT get_scalar(const std::string& name) const;
  /// @brief Read a vector named @p name into @p vec (presumably from the
  ///        loaded graph).
  template <class VT>
  void get_vector(std::vector<VT>& vec, const std::string& name) const;
  /// @brief Run @p session on @p input_tensors and map the result into
  ///        @p d_tensor_ using @p atommap / @p sel_fwd (details in the .cc).
  template <typename MODELTYPE, typename VALUETYPE>
  void run_model(std::vector<VALUETYPE>& d_tensor_,
                 tensorflow::Session* session,
                 const std::vector<std::pair<std::string, tensorflow::Tensor>>&
                     input_tensors,
                 const AtomMap& atommap,
                 const std::vector<int>& sel_fwd,
                 const int nghost = 0);
  /// @brief Run @p session and collect global/atomic tensors together with
  ///        force/virial outputs (details in the .cc).
  template <typename MODELTYPE, typename VALUETYPE>
  void run_model(std::vector<VALUETYPE>& dglobal_tensor_,
                 std::vector<VALUETYPE>& dforce_,
                 std::vector<VALUETYPE>& dvirial_,
                 std::vector<VALUETYPE>& datom_tensor_,
                 std::vector<VALUETYPE>& datom_virial_,
                 tensorflow::Session* session,
                 const std::vector<std::pair<std::string, tensorflow::Tensor>>&
                     input_tensors,
                 const AtomMap& atommap,
                 const std::vector<int>& sel_fwd,
                 const int nghost = 0);

  // Implementation helpers mirroring the compute() overloads one-to-one;
  // presumably compute() delegates to these — confirm.
  template <typename VALUETYPE>
  void compute_inner(std::vector<VALUETYPE>& value,
                     const std::vector<VALUETYPE>& coord,
                     const std::vector<int>& atype,
                     const std::vector<VALUETYPE>& box);
  template <typename VALUETYPE>
  void compute_inner(std::vector<VALUETYPE>& value,
                     const std::vector<VALUETYPE>& coord,
                     const std::vector<int>& atype,
                     const std::vector<VALUETYPE>& box,
                     const int nghost,
                     const InputNlist& inlist);
  template <typename VALUETYPE>
  void compute_inner(std::vector<VALUETYPE>& global_tensor,
                     std::vector<VALUETYPE>& force,
                     std::vector<VALUETYPE>& virial,
                     std::vector<VALUETYPE>& atom_tensor,
                     std::vector<VALUETYPE>& atom_virial,
                     const std::vector<VALUETYPE>& coord,
                     const std::vector<int>& atype,
                     const std::vector<VALUETYPE>& box);
  template <typename VALUETYPE>
  void compute_inner(std::vector<VALUETYPE>& global_tensor,
                     std::vector<VALUETYPE>& force,
                     std::vector<VALUETYPE>& virial,
                     std::vector<VALUETYPE>& atom_tensor,
                     std::vector<VALUETYPE>& atom_virial,
                     const std::vector<VALUETYPE>& coord,
                     const std::vector<int>& atype,
                     const std::vector<VALUETYPE>& box,
                     const int nghost,
                     const InputNlist& inlist);
};
}  // namespace deepmd