Program Listing for File gpu_rocm.h
↰ Return to documentation for file (source/lib/include/gpu_rocm.h
)
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once
#include <assert.h>
#include <hip/hip_runtime.h>
#include <stdio.h>
#include <string>
#include <vector>
// #include<rocprim/rocprim.hpp>
// #include <hipcub/hipcub.hpp>
#include "errors.h"
#define GPU_MAX_NBOR_SIZE 4096
#define gpuGetLastError hipGetLastError
#define gpuDeviceSynchronize hipDeviceSynchronize
#define gpuMemcpy hipMemcpy
#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
#define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define gpuMemset hipMemset
#define DPErrcheck(res) \
{ DPAssert((res), __FILE__, __LINE__); }
inline void DPAssert(hipError_t code,
const char *file,
int line,
bool abort = true) {
if (code != hipSuccess) {
std::string error_msg = "HIP runtime library throws an error: " +
std::string(hipGetErrorString(code)) +
", in file " + std::string(file) + ": " +
std::to_string(line);
if (abort) {
throw deepmd::deepmd_exception(error_msg);
} else {
fprintf(stderr, "%s\n", error_msg.c_str());
}
}
}
#define nborErrcheck(res) \
{ nborAssert((res), __FILE__, __LINE__); }
inline void nborAssert(hipError_t code,
const char *file,
int line,
bool abort = true) {
if (code != hipSuccess) {
std::string error_msg = "DeePMD-kit: Illegal nbor list sorting: ";
try {
DPAssert(code, file, line, true);
} catch (deepmd::deepmd_exception &e) {
error_msg += e.what();
if (abort) {
throw deepmd::deepmd_exception(error_msg);
} else {
fprintf(stderr, "%s\n", error_msg.c_str());
}
}
}
}
namespace deepmd {
inline void DPGetDeviceCount(int &gpu_num) { hipGetDeviceCount(&gpu_num); }
inline hipError_t DPSetDevice(int rank) { return hipSetDevice(rank); }
template <typename FPTYPE>
void memcpy_host_to_device(FPTYPE *device, std::vector<FPTYPE> &host) {
DPErrcheck(hipMemcpy(device, &host[0], sizeof(FPTYPE) * host.size(),
hipMemcpyHostToDevice));
}
template <typename FPTYPE>
void memcpy_host_to_device(FPTYPE *device, const FPTYPE *host, const int size) {
DPErrcheck(
hipMemcpy(device, host, sizeof(FPTYPE) * size, hipMemcpyHostToDevice));
}
template <typename FPTYPE>
void memcpy_device_to_host(const FPTYPE *device, std::vector<FPTYPE> &host) {
DPErrcheck(hipMemcpy(&host[0], device, sizeof(FPTYPE) * host.size(),
hipMemcpyDeviceToHost));
}
template <typename FPTYPE>
void memcpy_device_to_host(const FPTYPE *device, FPTYPE *host, const int size) {
DPErrcheck(
hipMemcpy(host, device, sizeof(FPTYPE) * size, hipMemcpyDeviceToHost));
}
template <typename FPTYPE>
void malloc_device_memory(FPTYPE *&device, std::vector<FPTYPE> &host) {
DPErrcheck(hipMalloc((void **)&device, sizeof(FPTYPE) * host.size()));
}
template <typename FPTYPE>
void malloc_device_memory(FPTYPE *&device, const int size) {
DPErrcheck(hipMalloc((void **)&device, sizeof(FPTYPE) * size));
}
template <typename FPTYPE>
void malloc_device_memory_sync(FPTYPE *&device, std::vector<FPTYPE> &host) {
DPErrcheck(hipMalloc((void **)&device, sizeof(FPTYPE) * host.size()));
memcpy_host_to_device(device, host);
}
template <typename FPTYPE>
void malloc_device_memory_sync(FPTYPE *&device,
const FPTYPE *host,
const int size) {
DPErrcheck(hipMalloc((void **)&device, sizeof(FPTYPE) * size));
memcpy_host_to_device(device, host, size);
}
template <typename FPTYPE>
void delete_device_memory(FPTYPE *&device) {
if (device != NULL) {
DPErrcheck(hipFree(device));
}
}
template <typename FPTYPE>
void memset_device_memory(FPTYPE *device, const int var, const int size) {
DPErrcheck(hipMemset(device, var, sizeof(FPTYPE) * size));
}
} // namespace deepmd