Program Listing for File DeepPot.h

Return to the documentation for file source/api_cc/include/DeepPot.h.

// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include <cassert>
#include <memory>
#include <string>
#include <vector>

#include "common.h"
#include "neighbor_list.h"

namespace deepmd {
/**
 * @brief Abstract interface for a Deep Potential evaluation backend.
 *
 * Every evaluation and metadata entry point is pure virtual, so this class is
 * only usable through a derived implementation; deepmd::DeepPot holds one via
 * std::shared_ptr<DeepPotBase>. The destructor is virtual so that a derived
 * object can be destroyed through a base pointer.
 *
 * Each computew* function comes in a double and a float flavour. In the float
 * flavour the geometry / force / virial vectors are float, but `ener` is still
 * std::vector<double>.
 *
 * NOTE(review): the declaration order of the virtual functions below fixes the
 * vtable layout; reordering them is an ABI break for prebuilt backends.
 */
class DeepPotBase {
 public:
  DeepPotBase() = default;
  virtual ~DeepPotBase() = default;
  /**
   * @brief Construct the backend and load a model.
   * @param model Model to load (presumably a file path — confirm).
   * @param gpu_rank Presumably the GPU device index to use; defaults to 0.
   * @param file_content Optional in-memory model content; empty by default.
   */
  explicit DeepPotBase(const std::string& model,
                       const int& gpu_rank = 0,
                       const std::string& file_content = "");
  /**
   * @brief Load a model into a default-constructed backend.
   *
   * Parameters mirror the loading constructor above.
   */
  virtual void init(const std::string& model,
                    const int& gpu_rank = 0,
                    const std::string& file_content = "") = 0;

  /**
   * @brief Double-precision evaluation without an external neighbor list.
   *
   * Results are written into the non-const outputs `ener`, `force`, `virial`,
   * `atom_energy` and `atom_virial`. `fparam` / `aparam` default to empty.
   * NOTE(review): the memory layout of coord/box/force (e.g. flattened xyz per
   * atom) is defined by the implementation — confirm there.
   */
  virtual void computew(
      std::vector<double>& ener,
      std::vector<double>& force,
      std::vector<double>& virial,
      std::vector<double>& atom_energy,
      std::vector<double>& atom_virial,
      const std::vector<double>& coord,
      const std::vector<int>& atype,
      const std::vector<double>& box,
      const std::vector<double>& fparam = std::vector<double>(),
      const std::vector<double>& aparam = std::vector<double>()) = 0;
  /// @brief Single-precision counterpart of the overload above (`ener` stays
  /// double).
  virtual void computew(
      std::vector<double>& ener,
      std::vector<float>& force,
      std::vector<float>& virial,
      std::vector<float>& atom_energy,
      std::vector<float>& atom_virial,
      const std::vector<float>& coord,
      const std::vector<int>& atype,
      const std::vector<float>& box,
      const std::vector<float>& fparam = std::vector<float>(),
      const std::vector<float>& aparam = std::vector<float>()) = 0;
  /**
   * @brief Double-precision evaluation using a caller-supplied neighbor list.
   *
   * NOTE(review): `nghost` presumably counts ghost atoms included in
   * `coord`/`atype`, and `ago` presumably tells the backend whether `inlist`
   * changed since the previous call — confirm against the implementation.
   */
  virtual void computew(
      std::vector<double>& ener,
      std::vector<double>& force,
      std::vector<double>& virial,
      std::vector<double>& atom_energy,
      std::vector<double>& atom_virial,
      const std::vector<double>& coord,
      const std::vector<int>& atype,
      const std::vector<double>& box,
      const int nghost,
      const InputNlist& inlist,
      const int& ago,
      const std::vector<double>& fparam = std::vector<double>(),
      const std::vector<double>& aparam = std::vector<double>()) = 0;
  /// @brief Single-precision counterpart of the neighbor-list overload above
  /// (`ener` stays double).
  virtual void computew(
      std::vector<double>& ener,
      std::vector<float>& force,
      std::vector<float>& virial,
      std::vector<float>& atom_energy,
      std::vector<float>& atom_virial,
      const std::vector<float>& coord,
      const std::vector<int>& atype,
      const std::vector<float>& box,
      const int nghost,
      const InputNlist& inlist,
      const int& ago,
      const std::vector<float>& fparam = std::vector<float>(),
      const std::vector<float>& aparam = std::vector<float>()) = 0;
  /**
   * @brief Double-precision evaluation of `nframes` frames in one call.
   *
   * NOTE(review): the name suggests `atype` may differ between frames
   * ("mixed type" input) — confirm against the implementation.
   */
  virtual void computew_mixed_type(
      std::vector<double>& ener,
      std::vector<double>& force,
      std::vector<double>& virial,
      std::vector<double>& atom_energy,
      std::vector<double>& atom_virial,
      const int& nframes,
      const std::vector<double>& coord,
      const std::vector<int>& atype,
      const std::vector<double>& box,
      const std::vector<double>& fparam = std::vector<double>(),
      const std::vector<double>& aparam = std::vector<double>()) = 0;
  /// @brief Single-precision counterpart of the mixed-type overload above
  /// (`ener` stays double).
  virtual void computew_mixed_type(
      std::vector<double>& ener,
      std::vector<float>& force,
      std::vector<float>& virial,
      std::vector<float>& atom_energy,
      std::vector<float>& atom_virial,
      const int& nframes,
      const std::vector<float>& coord,
      const std::vector<int>& atype,
      const std::vector<float>& box,
      const std::vector<float>& fparam = std::vector<float>(),
      const std::vector<float>& aparam = std::vector<float>()) = 0;
  /// @brief Cutoff reported by the loaded model.
  virtual double cutoff() const = 0;
  /// @brief Number of types reported by the loaded model.
  virtual int numb_types() const = 0;
  /// @brief Number of spin types reported by the loaded model.
  virtual int numb_types_spin() const = 0;
  /// @brief Expected length of `fparam`.
  virtual int dim_fparam() const = 0;
  /// @brief Expected dimension of `aparam`.
  virtual int dim_aparam() const = 0;
  /// @brief Write the model's type map into `type_map`.
  virtual void get_type_map(std::string& type_map) = 0;

  /// @brief Whether `aparam` covers all atoms rather than only local ones
  /// (presumably including ghosts, given the name — confirm).
  virtual bool is_aparam_nall() const = 0;
};

class DeepPot {
 public:
  DeepPot();
  ~DeepPot();
  DeepPot(const std::string& model,
          const int& gpu_rank = 0,
          const std::string& file_content = "");
  void init(const std::string& model,
            const int& gpu_rank = 0,
            const std::string& file_content = "");

  void print_summary(const std::string& pre) const;
  template <typename VALUETYPE>
  void compute(ENERGYTYPE& ener,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
               const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  template <typename VALUETYPE>
  void compute(std::vector<ENERGYTYPE>& ener,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
               const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  template <typename VALUETYPE>
  void compute(ENERGYTYPE& ener,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& inlist,
               const int& ago,
               const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
               const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  template <typename VALUETYPE>
  void compute(std::vector<ENERGYTYPE>& ener,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& inlist,
               const int& ago,
               const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
               const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  template <typename VALUETYPE>
  void compute(ENERGYTYPE& ener,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               std::vector<VALUETYPE>& atom_energy,
               std::vector<VALUETYPE>& atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
               const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  template <typename VALUETYPE>
  void compute(std::vector<ENERGYTYPE>& ener,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               std::vector<VALUETYPE>& atom_energy,
               std::vector<VALUETYPE>& atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
               const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  template <typename VALUETYPE>
  void compute(ENERGYTYPE& ener,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               std::vector<VALUETYPE>& atom_energy,
               std::vector<VALUETYPE>& atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& lmp_list,
               const int& ago,
               const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
               const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  template <typename VALUETYPE>
  void compute(std::vector<ENERGYTYPE>& ener,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               std::vector<VALUETYPE>& atom_energy,
               std::vector<VALUETYPE>& atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& lmp_list,
               const int& ago,
               const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
               const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  template <typename VALUETYPE>
  void compute_mixed_type(
      ENERGYTYPE& ener,
      std::vector<VALUETYPE>& force,
      std::vector<VALUETYPE>& virial,
      const int& nframes,
      const std::vector<VALUETYPE>& coord,
      const std::vector<int>& atype,
      const std::vector<VALUETYPE>& box,
      const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
      const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  template <typename VALUETYPE>
  void compute_mixed_type(
      std::vector<ENERGYTYPE>& ener,
      std::vector<VALUETYPE>& force,
      std::vector<VALUETYPE>& virial,
      const int& nframes,
      const std::vector<VALUETYPE>& coord,
      const std::vector<int>& atype,
      const std::vector<VALUETYPE>& box,
      const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
      const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  template <typename VALUETYPE>
  void compute_mixed_type(
      ENERGYTYPE& ener,
      std::vector<VALUETYPE>& force,
      std::vector<VALUETYPE>& virial,
      std::vector<VALUETYPE>& atom_energy,
      std::vector<VALUETYPE>& atom_virial,
      const int& nframes,
      const std::vector<VALUETYPE>& coord,
      const std::vector<int>& atype,
      const std::vector<VALUETYPE>& box,
      const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
      const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  template <typename VALUETYPE>
  void compute_mixed_type(
      std::vector<ENERGYTYPE>& ener,
      std::vector<VALUETYPE>& force,
      std::vector<VALUETYPE>& virial,
      std::vector<VALUETYPE>& atom_energy,
      std::vector<VALUETYPE>& atom_virial,
      const int& nframes,
      const std::vector<VALUETYPE>& coord,
      const std::vector<int>& atype,
      const std::vector<VALUETYPE>& box,
      const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
      const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  double cutoff() const;
  int numb_types() const;
  int numb_types_spin() const;
  int dim_fparam() const;
  int dim_aparam() const;
  void get_type_map(std::string& type_map);

  bool is_aparam_nall() const;

 private:
  bool inited;
  std::shared_ptr<deepmd::DeepPotBase> dp;
};

class DeepPotModelDevi {
 public:
  DeepPotModelDevi();
  ~DeepPotModelDevi();
  DeepPotModelDevi(const std::vector<std::string>& models,
                   const int& gpu_rank = 0,
                   const std::vector<std::string>& file_contents =
                       std::vector<std::string>());
  void init(const std::vector<std::string>& models,
            const int& gpu_rank = 0,
            const std::vector<std::string>& file_contents =
                std::vector<std::string>());

  template <typename VALUETYPE>
  void compute(std::vector<ENERGYTYPE>& all_ener,
               std::vector<std::vector<VALUETYPE> >& all_force,
               std::vector<std::vector<VALUETYPE> >& all_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& lmp_list,
               const int& ago,
               const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
               const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  template <typename VALUETYPE>
  void compute(std::vector<ENERGYTYPE>& all_ener,
               std::vector<std::vector<VALUETYPE> >& all_force,
               std::vector<std::vector<VALUETYPE> >& all_virial,
               std::vector<std::vector<VALUETYPE> >& all_atom_energy,
               std::vector<std::vector<VALUETYPE> >& all_atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& lmp_list,
               const int& ago,
               const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
               const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  double cutoff() const {
    assert(inited);
    return dps[0].cutoff();
  };
  int numb_types() const {
    assert(inited);
    return dps[0].numb_types();
  };
  int numb_types_spin() const {
    assert(inited);
    return dps[0].numb_types_spin();
  };
  int dim_fparam() const {
    assert(inited);
    return dps[0].dim_fparam();
  };
  int dim_aparam() const {
    assert(inited);
    return dps[0].dim_aparam();
  };
  template <typename VALUETYPE>
  void compute_avg(VALUETYPE& dener, const std::vector<VALUETYPE>& all_energy);
  template <typename VALUETYPE>
  void compute_avg(std::vector<VALUETYPE>& avg,
                   const std::vector<std::vector<VALUETYPE> >& xx);
  template <typename VALUETYPE>
  void compute_std(std::vector<VALUETYPE>& std,
                   const std::vector<VALUETYPE>& avg,
                   const std::vector<std::vector<VALUETYPE> >& xx,
                   const int& stride);
  template <typename VALUETYPE>
  void compute_relative_std(std::vector<VALUETYPE>& std,
                            const std::vector<VALUETYPE>& avg,
                            const VALUETYPE eps,
                            const int& stride);
  template <typename VALUETYPE>
  void compute_std_e(std::vector<VALUETYPE>& std,
                     const std::vector<VALUETYPE>& avg,
                     const std::vector<std::vector<VALUETYPE> >& xx);
  template <typename VALUETYPE>
  void compute_std_f(std::vector<VALUETYPE>& std,
                     const std::vector<VALUETYPE>& avg,
                     const std::vector<std::vector<VALUETYPE> >& xx);
  template <typename VALUETYPE>
  void compute_relative_std_f(std::vector<VALUETYPE>& std,
                              const std::vector<VALUETYPE>& avg,
                              const VALUETYPE eps);
  bool is_aparam_nall() const {
    assert(inited);
    return dps[0].is_aparam_nall();
  };

 private:
  unsigned numb_models;
  std::vector<deepmd::DeepPot> dps;
  bool inited;
};
}  // namespace deepmd