Commit ebc7b38c authored by Berrill, Mark's avatar Berrill, Mark
Browse files

Fixing minor bug with HIP

parent 9e39bb2a
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -11,10 +11,10 @@
#endif

#ifdef USE_CUDA
#include <cuda_runtime_api.h>
extern int getCudaDeviceCount();
#endif
#ifdef USE_HIP
#include "hip/hip_runtime.h"
extern int getHipDeviceCount();
#endif

// Detect the OS
@@ -259,9 +259,9 @@ void printHardware()
    // Get number of gpus
    int N_gpu = 0;
#ifdef USE_CUDA
    cudaGetDeviceCount( &N_gpu );
    N_gpu = getCudaDeviceCount();
#elif defined( USE_HIP )
    hipGetDeviceCount( &N_gpu );
    N_gpu         = getHipDeviceCount();
#endif
    // Get system memory
#if defined( USE_LINUX )
+8 −16
Original line number Diff line number Diff line
@@ -26,19 +26,20 @@
#endif
#ifdef USE_CUDA
#define ENABLE_CUDA
extern int getCudaDeviceCount();
extern void setCudaGPU( int );
#undef USE_CUDA
#endif
#ifdef USE_HIP
#define ENABLE_HIP
extern int getHipDeviceCount();
extern void setHipGPU( int );
#undef USE_HIP
#endif
#ifdef USE_OPENMP
#define ENABLE_OPENMP
#undef USE_OPENMP
#endif
#ifdef ENABLE_CUDA
#include <cuda_runtime_api.h>
#endif
#include "RayTrace/common/RayTraceDefinitions.h"
#include "RayTrace/common/RayTraceImageHelper.h"
#include "RayTrace/utilities/RayUtilityMacros.h"
@@ -92,13 +93,6 @@ static inline std::vector<double> createGrid( size_t N, double dx, const double
/**********************************************************************
 * Call RayTraceImage function from a thread loop                      *
 **********************************************************************/
void setGPU( int id )
{
    NULL_USE( id );
#if defined( ENABLE_CUDA )
    cudaSetDevice( id );
#endif
}
void setDeviceAndRun( int id, int N_threads, std::function<void( int )> setID,
    std::function<void( int, const RayTrace::EUV_beam_struct &, const RayTrace::ray_gain_struct *,
        const RayTrace::ray_seed_struct *, int, const std::vector<ray_struct> &, double, double *,
@@ -359,9 +353,8 @@ void RayTrace::create_image( create_image_struct *info, std::string compute_meth
#endif
    } else if ( compute_method == "cuda-multigpu" ) {
#if defined( ENABLE_CUDA )
        int N_gpu;
        cudaGetDeviceCount( &N_gpu );
        RayTraceImageThreadLoop( N_gpu, RayTraceImageCudaLoop, setGPU, N,
        int N_gpu = getCudaDeviceCount();
        RayTraceImageThreadLoop( N_gpu, RayTraceImageCudaLoop, setCudaGPU, N,
            std::ref( *info->euv_beam ), info->gain, info->seed, method, rays, scale, image, I_ang,
            failure_code, failed_rays );
#else
@@ -376,9 +369,8 @@ void RayTrace::create_image( create_image_struct *info, std::string compute_meth
#endif
    } else if ( compute_method == "hip-multigpu" ) {
#if defined( ENABLE_HIP )
        int N_gpu;
        hipGetDeviceCount( &N_gpu );
        RayTraceImageThreadLoop( N_gpu, RayTraceImageHIPLoop, setGPU, N,
        int N_gpu = getHipDeviceCount();
        RayTraceImageThreadLoop( N_gpu, RayTraceImageHIPLoop, setHipGPU, N,
            std::ref( *info->euv_beam ), info->gain, info->seed, method, rays, scale, image, I_ang,
            failure_code, failed_rays );
#else
+12 −0
Original line number Diff line number Diff line
@@ -78,6 +78,18 @@ __device__ inline int getIndex( int n, const double x_range[2], double dx, doubl
}


// Helper functions
int getCudaDeviceCount()
{
    int N_gpu = 0;
    cudaGetDeviceCount( &N_gpu );
    return N_gpu;
}
void setCudaGPU( int id )
{
    cudaSetDevice( id );
}


// Kernel that executes on the CUDA device
__global__
+13 −0
Original line number Diff line number Diff line
@@ -45,6 +45,19 @@ __device__ inline int getIndex( int n, const double x_range[2], double dx, doubl
}


// Helper functions
int getHipDeviceCount()
{
    int N_gpu = 0;
    hipGetDeviceCount( &N_gpu );
    return N_gpu;
}
void setHipGPU( int id )
{
    hipSetDevice( id );
}


// Kernel that executes on the device
__global__ __launch_bounds__( 128, 8 ) // Set bounds to limit the number of registers
    void RayTraceImageHIPKernel( int N, int nx, int ny, int na, int nb, int nv, double x0,