Program Listing for File DeepPotPT.h
↰ Return to documentation for file (source/api_cc/include/DeepPotPT.h)
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once
#include <cassert>
#include <string>
#include <vector>

#include <torch/script.h>
#include <torch/torch.h>

#include "DeepPot.h"
namespace deepmd {
/**
 * @brief DeepPotBase implementation that holds a TorchScript module
 *        (torch::jit::script::Module).
 *
 * Only declarations live in this header; constructors, init(), the compute
 * templates and the computew() entry points are defined out of line.
 *
 * All scalar data members carry default member initializers so that the
 * object is in a determinate state even before init() has run. The accessors
 * below guard on `inited` with assert(), which is compiled out when NDEBUG is
 * defined, so without the defaults a release build could silently return an
 * indeterminate value.
 **/
class DeepPotPT : public DeepPotBase {
 public:
  /**
   * @brief Default constructor (defined out of line).
   **/
  DeepPotPT();
  /**
   * @brief Destructor (defined out of line).
   **/
  ~DeepPotPT();
  /**
   * @brief Construct from a model.
   * NOTE(review): presumably equivalent to default construction followed by
   * init() with the same arguments - confirm against the implementation.
   * @param[in] model Model identifier; presumably the path of the model file.
   * @param[in] gpu_rank GPU rank hint, defaults to 0.
   * @param[in] file_content Defaults to the empty string; presumably the
   *            in-memory content of the model file - TODO confirm.
   **/
  DeepPotPT(const std::string& model,
            const int& gpu_rank = 0,
            const std::string& file_content = "");
  /**
   * @brief Initialize the object (defined out of line).
   * @param[in] model Model identifier; presumably the path of the model file.
   * @param[in] gpu_rank GPU rank hint, defaults to 0.
   * @param[in] file_content Defaults to the empty string; presumably the
   *            in-memory content of the model file - TODO confirm.
   **/
  void init(const std::string& model,
            const int& gpu_rank = 0,
            const std::string& file_content = "");

 private:
  /**
   * @brief Templated evaluation without an externally supplied neighbor list.
   *
   * The first five parameters are non-const references and are presumably
   * pure outputs; the rest are read-only inputs. Array layouts and units are
   * not specified in this header.
   *
   * @param ener Energy result.
   * @param force Force result.
   * @param virial Virial result.
   * @param atom_energy Per-atom energy result (presumably).
   * @param atom_virial Per-atom virial result (presumably).
   * @param[in] coord Atomic coordinates.
   * @param[in] atype Atom types.
   * @param[in] box Presumably the simulation cell - TODO confirm.
   * @param[in] fparam Optional, defaults to empty; presumably frame
   *            parameters (see dim_fparam()).
   * @param[in] aparam Optional, defaults to empty; presumably atomic
   *            parameters (see dim_aparam()).
   **/
  template <typename VALUETYPE, typename ENERGYVTYPE>
  void compute(ENERGYVTYPE& ener,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               std::vector<VALUETYPE>& atom_energy,
               std::vector<VALUETYPE>& atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
               const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  /**
   * @brief Templated evaluation with an externally supplied neighbor list.
   *
   * Same leading parameters as the overload above, plus the neighbor-list
   * related inputs.
   *
   * @param ener Energy result.
   * @param force Force result.
   * @param virial Virial result.
   * @param atom_energy Per-atom energy result (presumably).
   * @param atom_virial Per-atom virial result (presumably).
   * @param[in] coord Atomic coordinates.
   * @param[in] atype Atom types.
   * @param[in] box Presumably the simulation cell - TODO confirm.
   * @param[in] nghost Presumably the number of ghost atoms - TODO confirm.
   * @param[in] lmp_list Input neighbor list (InputNlist).
   * @param[in] ago NOTE(review): semantics are not visible here; presumably
   *            tells how stale the neighbor list is - confirm.
   * @param[in] fparam Optional, defaults to empty.
   * @param[in] aparam Optional, defaults to empty.
   **/
  template <typename VALUETYPE, typename ENERGYVTYPE>
  void compute(ENERGYVTYPE& ener,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               std::vector<VALUETYPE>& atom_energy,
               std::vector<VALUETYPE>& atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& lmp_list,
               const int& ago,
               const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
               const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  /**
   * @brief Templated mixed-type evaluation without per-atom results.
   * @param ener Energy result.
   * @param force Force result.
   * @param virial Virial result.
   * @param[in] nframes Number of frames.
   * @param[in] coord Atomic coordinates.
   * @param[in] atype Atom types.
   * @param[in] box Presumably the simulation cell - TODO confirm.
   * @param[in] fparam Optional, defaults to empty.
   * @param[in] aparam Optional, defaults to empty.
   **/
  template <typename VALUETYPE, typename ENERGYVTYPE>
  void compute_mixed_type(
      ENERGYVTYPE& ener,
      std::vector<VALUETYPE>& force,
      std::vector<VALUETYPE>& virial,
      const int& nframes,
      const std::vector<VALUETYPE>& coord,
      const std::vector<int>& atype,
      const std::vector<VALUETYPE>& box,
      const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
      const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  /**
   * @brief Templated mixed-type evaluation with per-atom results.
   * @param ener Energy result.
   * @param force Force result.
   * @param virial Virial result.
   * @param atom_energy Per-atom energy result (presumably).
   * @param atom_virial Per-atom virial result (presumably).
   * @param[in] nframes Number of frames.
   * @param[in] coord Atomic coordinates.
   * @param[in] atype Atom types.
   * @param[in] box Presumably the simulation cell - TODO confirm.
   * @param[in] fparam Optional, defaults to empty.
   * @param[in] aparam Optional, defaults to empty.
   **/
  template <typename VALUETYPE, typename ENERGYVTYPE>
  void compute_mixed_type(
      ENERGYVTYPE& ener,
      std::vector<VALUETYPE>& force,
      std::vector<VALUETYPE>& virial,
      std::vector<VALUETYPE>& atom_energy,
      std::vector<VALUETYPE>& atom_virial,
      const int& nframes,
      const std::vector<VALUETYPE>& coord,
      const std::vector<int>& atype,
      const std::vector<VALUETYPE>& box,
      const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
      const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());

 public:
  // Every accessor below asserts `inited` before returning the stored value.
  // The assertion is active only when NDEBUG is not defined.

  /**
   * @brief Get the stored cutoff value.
   * @return The value of `rcut`.
   **/
  double cutoff() const {
    assert(inited);
    return rcut;
  }
  /**
   * @brief Get the stored number of types.
   * @return The value of `ntypes`.
   **/
  int numb_types() const {
    assert(inited);
    return ntypes;
  }
  /**
   * @brief Get the stored number of spin types.
   * @return The value of `ntypes_spin`.
   **/
  int numb_types_spin() const {
    assert(inited);
    return ntypes_spin;
  }
  /**
   * @brief Get the stored fparam dimension.
   * @return The value of `dfparam`.
   **/
  int dim_fparam() const {
    assert(inited);
    return dfparam;
  }
  /**
   * @brief Get the stored aparam dimension.
   * @return The value of `daparam`.
   **/
  int dim_aparam() const {
    assert(inited);
    return daparam;
  }
  /**
   * @brief Write the type map into the given string (defined out of line).
   * @param[out] type_map Receives the type map; its format is not specified
   *             in this header.
   **/
  void get_type_map(std::string& type_map);
  /**
   * @brief Get the stored `aparam_nall` flag.
   * NOTE(review): presumably tells whether aparam covers all atoms including
   * ghosts rather than only local ones - confirm against the implementation.
   * @return The value of `aparam_nall`.
   **/
  bool is_aparam_nall() const {
    assert(inited);
    return aparam_nall;
  }

  // forward to template class
  // The computew*/computew_mixed_type overloads below are the non-template
  // entry points, provided for double and float value types. The energy is
  // always passed as std::vector<double>, including in the float overloads.
  // Their definitions are not in this header.

  /**
   * @brief Double-precision evaluation without an external neighbor list.
   **/
  void computew(std::vector<double>& ener,
                std::vector<double>& force,
                std::vector<double>& virial,
                std::vector<double>& atom_energy,
                std::vector<double>& atom_virial,
                const std::vector<double>& coord,
                const std::vector<int>& atype,
                const std::vector<double>& box,
                const std::vector<double>& fparam = std::vector<double>(),
                const std::vector<double>& aparam = std::vector<double>());
  /**
   * @brief Single-precision evaluation without an external neighbor list.
   **/
  void computew(std::vector<double>& ener,
                std::vector<float>& force,
                std::vector<float>& virial,
                std::vector<float>& atom_energy,
                std::vector<float>& atom_virial,
                const std::vector<float>& coord,
                const std::vector<int>& atype,
                const std::vector<float>& box,
                const std::vector<float>& fparam = std::vector<float>(),
                const std::vector<float>& aparam = std::vector<float>());
  /**
   * @brief Double-precision evaluation with an external neighbor list.
   **/
  void computew(std::vector<double>& ener,
                std::vector<double>& force,
                std::vector<double>& virial,
                std::vector<double>& atom_energy,
                std::vector<double>& atom_virial,
                const std::vector<double>& coord,
                const std::vector<int>& atype,
                const std::vector<double>& box,
                const int nghost,
                const InputNlist& inlist,
                const int& ago,
                const std::vector<double>& fparam = std::vector<double>(),
                const std::vector<double>& aparam = std::vector<double>());
  /**
   * @brief Single-precision evaluation with an external neighbor list.
   **/
  void computew(std::vector<double>& ener,
                std::vector<float>& force,
                std::vector<float>& virial,
                std::vector<float>& atom_energy,
                std::vector<float>& atom_virial,
                const std::vector<float>& coord,
                const std::vector<int>& atype,
                const std::vector<float>& box,
                const int nghost,
                const InputNlist& inlist,
                const int& ago,
                const std::vector<float>& fparam = std::vector<float>(),
                const std::vector<float>& aparam = std::vector<float>());
  /**
   * @brief Double-precision mixed-type evaluation over `nframes` frames.
   **/
  void computew_mixed_type(
      std::vector<double>& ener,
      std::vector<double>& force,
      std::vector<double>& virial,
      std::vector<double>& atom_energy,
      std::vector<double>& atom_virial,
      const int& nframes,
      const std::vector<double>& coord,
      const std::vector<int>& atype,
      const std::vector<double>& box,
      const std::vector<double>& fparam = std::vector<double>(),
      const std::vector<double>& aparam = std::vector<double>());
  /**
   * @brief Single-precision mixed-type evaluation over `nframes` frames.
   **/
  void computew_mixed_type(
      std::vector<double>& ener,
      std::vector<float>& force,
      std::vector<float>& virial,
      std::vector<float>& atom_energy,
      std::vector<float>& atom_virial,
      const int& nframes,
      const std::vector<float>& coord,
      const std::vector<int>& atype,
      const std::vector<float>& box,
      const std::vector<float>& fparam = std::vector<float>(),
      const std::vector<float>& aparam = std::vector<float>());

 private:
  // Declaration order is kept exactly as before: it fixes the member
  // initialization order used by the out-of-line constructors.
  // The default member initializers only take effect for members that a
  // constructor does not initialize itself.
  int num_intra_nthreads = 0;
  int num_inter_nthreads = 0;
  // Guard flag checked (via assert) by every accessor.
  bool inited = false;
  int ntypes = 0;
  int ntypes_spin = 0;
  int dfparam = 0;
  int daparam = 0;
  bool aparam_nall = false;
  // copy neighbor list info from host
  // NOTE(review): the comment above sits directly before `module`, but
  // presumably describes `nlist_data` below - confirm.
  torch::jit::script::Module module;
  double rcut = 0.0;
  NeighborListData nlist_data;
  int max_num_neighbors = 0;
  int gpu_id = 0;
  int do_message_passing = 0;  // 1:dpa2 model 0:others
  bool gpu_enabled = false;
  at::Tensor firstneigh_tensor;
  torch::Dict<std::string, torch::Tensor> comm_dict;
};
} // namespace deepmd