Program Listing for File DeepTensor.h

Return to documentation for file (source/api_cc/include/DeepTensor.h)

// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include <memory>
#include <string>
#include <vector>

#include "common.h"
#include "neighbor_list.h"

namespace deepmd {
/**
 * @brief Abstract backend interface for tensor-model evaluation.
 *
 * Apart from the constructors and the destructor, every member is pure
 * virtual, so this type only describes a contract; a concrete backend must
 * derive from it and implement all of the operations. Each compute entry
 * point exists in a double and a float flavour, each with and without a
 * caller-supplied neighbor list.
 *
 * NOTE(review): the declaration order of the virtual functions below defines
 * the vtable layout; keep it stable so already-compiled backends stay
 * binary compatible.
 **/
class DeepTensorBase {
 public:
  /**
   * @brief Construct an object without loading anything (presumably init()
   * is expected to be called before any compute method — confirm).
   **/
  DeepTensorBase() = default;
  /**
   * @brief Virtual so that a backend destroyed through a DeepTensorBase
   * pointer runs its own destructor.
   **/
  virtual ~DeepTensorBase() = default;
  /**
   * @brief Construct and initialize in one step.
   * @param[in] model Model to load (presumably a path to a model file).
   * @param[in] gpu_rank GPU index to use (presumably). Defaults to 0.
   * @param[in] name_scope Name scope for the model's operations
   * (presumably). Defaults to the empty string.
   **/
  DeepTensorBase(const std::string& model,
                 const int& gpu_rank = 0,
                 const std::string& name_scope = "");
  /**
   * @brief Initialize the backend; parameters as for the initializing
   * constructor.
   *
   * NOTE: default arguments on a virtual function are taken from the static
   * type of the call expression, not from the override that actually runs.
   **/
  virtual void init(const std::string& model,
                    const int& gpu_rank = 0,
                    const std::string& name_scope = "") = 0;
  /**
   * @brief Evaluate the model in double precision without an external
   * neighbor list.
   * @param[out] global_tensor Tensor for the whole system (presumably).
   * @param[out] force Derivative output (presumably only meaningful when
   * @p request_deriv is true — confirm).
   * @param[out] virial Virial output (see @p force).
   * @param[out] atom_tensor Per-atom tensor output (presumably).
   * @param[out] atom_virial Per-atom virial output (presumably).
   * @param[in] coord Atomic coordinates (presumably natoms x 3, flattened).
   * @param[in] atype Atom type of each atom.
   * @param[in] box Simulation cell (presumably 9 values — confirm whether an
   * empty box is accepted).
   * @param[in] request_deriv Whether derivative outputs are requested
   * (presumably).
   **/
  virtual void computew(std::vector<double>& global_tensor,
                        std::vector<double>& force,
                        std::vector<double>& virial,
                        std::vector<double>& atom_tensor,
                        std::vector<double>& atom_virial,
                        const std::vector<double>& coord,
                        const std::vector<int>& atype,
                        const std::vector<double>& box,
                        const bool request_deriv) = 0;
  /**
   * @brief Single-precision counterpart of the overload above; the
   * parameters have the same meaning.
   **/
  virtual void computew(std::vector<float>& global_tensor,
                        std::vector<float>& force,
                        std::vector<float>& virial,
                        std::vector<float>& atom_tensor,
                        std::vector<float>& atom_virial,
                        const std::vector<float>& coord,
                        const std::vector<int>& atype,
                        const std::vector<float>& box,
                        const bool request_deriv) = 0;
  /**
   * @brief Double-precision evaluation with a caller-supplied neighbor list.
   *
   * Parameters shared with the overloads above have the same meaning.
   * @param[in] nghost Number of ghost atoms (presumably stored after the
   * local atoms in @p coord / @p atype — confirm).
   * @param[in] inlist Neighbor list provided by the caller.
   **/
  virtual void computew(std::vector<double>& global_tensor,
                        std::vector<double>& force,
                        std::vector<double>& virial,
                        std::vector<double>& atom_tensor,
                        std::vector<double>& atom_virial,
                        const std::vector<double>& coord,
                        const std::vector<int>& atype,
                        const std::vector<double>& box,
                        const int nghost,
                        const InputNlist& inlist,
                        const bool request_deriv) = 0;
  /**
   * @brief Single-precision counterpart of the neighbor-list overload above.
   **/
  virtual void computew(std::vector<float>& global_tensor,
                        std::vector<float>& force,
                        std::vector<float>& virial,
                        std::vector<float>& atom_tensor,
                        std::vector<float>& atom_virial,
                        const std::vector<float>& coord,
                        const std::vector<int>& atype,
                        const std::vector<float>& box,
                        const int nghost,
                        const InputNlist& inlist,
                        const bool request_deriv) = 0;
  /** @brief Cutoff of the model (presumably the neighbor-search radius). **/
  virtual double cutoff() const = 0;
  /** @brief Number of atom types known to the model (presumably). **/
  virtual int numb_types() const = 0;
  /** @brief Dimension of the output tensor (presumably per selected atom). **/
  virtual int output_dim() const = 0;
  /**
   * @brief Atom types selected for output (presumably). Returned by const
   * reference, so the storage must outlive the call — presumably owned by
   * the implementing object.
   **/
  virtual const std::vector<int>& sel_types() const = 0;
  /**
   * @brief Write the model's type map into @p type_map (an out-parameter).
   **/
  virtual void get_type_map(std::string& type_map) = 0;
};

/**
 * @brief User-facing tensor-model evaluator of the C++ API.
 *
 * Holds a std::shared_ptr to an implementation of the abstract
 * DeepTensorBase interface; presumably every public method forwards to that
 * backend once init() has succeeded (the definitions are not in this
 * header — confirm there).
 *
 * NOTE: the copy constructor is implicitly generated and copies the
 * std::shared_ptr member, so copies of a DeepTensor share one backend.
 **/
class DeepTensor {
 public:
  /**
   * @brief Construct without loading a model (presumably init() must be
   * called before use).
   **/
  DeepTensor();
  ~DeepTensor();
  /**
   * @brief Construct and initialize in one step.
   * @param[in] model Model to load (presumably a path to a model file).
   * @param[in] gpu_rank GPU index to use (presumably). Defaults to 0.
   * @param[in] name_scope Name scope for the model's operations
   * (presumably). Defaults to the empty string.
   **/
  DeepTensor(const std::string& model,
             const int& gpu_rank = 0,
             const std::string& name_scope = "");
  /**
   * @brief Initialize after default construction; parameters as for the
   * initializing constructor.
   **/
  void init(const std::string& model,
            const int& gpu_rank = 0,
            const std::string& name_scope = "");
  /**
   * @brief Print a summary of the model (presumably).
   * @param[in] pre Prefix for each printed line (presumably).
   **/
  void print_summary(const std::string& pre) const;

  /**
   * @brief Evaluate only the tensor output, without derivative outputs.
   * @tparam VALUETYPE Floating-point type; presumably double or float,
   * mirroring the two precisions offered by DeepTensorBase::computew.
   * @param[out] value Tensor output (presumably per selected atom).
   * @param[in] coord Atomic coordinates (presumably natoms x 3, flattened).
   * @param[in] atype Atom type of each atom.
   * @param[in] box Simulation cell (presumably 9 values — confirm whether an
   * empty box is accepted).
   **/
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& value,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box);
  /**
   * @brief As above, but with a caller-supplied neighbor list.
   * @param[in] nghost Number of ghost atoms (presumably stored after the
   * local atoms in @p coord / @p atype — confirm).
   * @param[in] inlist Neighbor list provided by the caller.
   **/
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& value,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& inlist);
  /**
   * @brief Evaluate the global tensor together with its derivative outputs.
   * @param[out] global_tensor Tensor for the whole system (presumably).
   * @param[out] force Derivative output (presumably w.r.t. coordinates).
   * @param[out] virial Virial output (presumably).
   * Remaining parameters as for the value-only overload.
   **/
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& global_tensor,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box);
  /**
   * @brief Global tensor and derivatives with a caller-supplied neighbor
   * list; @p nghost and @p inlist as in the neighbor-list value overload.
   **/
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& global_tensor,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& inlist);
  /**
   * @brief Like the global variant, additionally returning per-atom outputs.
   * @param[out] atom_tensor Per-atom tensor output (presumably).
   * @param[out] atom_virial Per-atom virial output (presumably).
   **/
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& global_tensor,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               std::vector<VALUETYPE>& atom_tensor,
               std::vector<VALUETYPE>& atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box);
  /**
   * @brief Per-atom variant with a caller-supplied neighbor list; @p nghost
   * and @p inlist as in the other neighbor-list overloads.
   **/
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& global_tensor,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               std::vector<VALUETYPE>& atom_tensor,
               std::vector<VALUETYPE>& atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& inlist);
  /** @brief Cutoff of the model (presumably the neighbor-search radius). **/
  double cutoff() const;
  /** @brief Number of atom types known to the model (presumably). **/
  int numb_types() const;
  /** @brief Dimension of the output tensor (presumably per selected atom). **/
  int output_dim() const;
  /** @brief Atom types selected for output (presumably). **/
  const std::vector<int>& sel_types() const;
  /** @brief Write the model's type map into @p type_map (out-parameter). **/
  void get_type_map(std::string& type_map);

 private:
  // Initialization flag (presumably set by init()). The in-class default
  // guarantees it is never indeterminate, even on a constructor path whose
  // out-of-line definition does not set it; a constructor's own
  // mem-initializer still takes precedence, so existing behavior is kept.
  bool inited = false;
  // Backend implementing the abstract DeepTensorBase interface.
  std::shared_ptr<deepmd::DeepTensorBase> dt;
};
}  // namespace deepmd