Commit d702530a authored by David M. Rogers's avatar David M. Rogers
Browse files

Added saxpy kernel test.

parent 1b49fc4a
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ if(CMAKE_CUDA_ARCHITECTURES)
    set(ENABLE_GPU TRUE)
    include(setup_cuda)

    find_package(CUB CONFIG REQUIRED)
    find_package(Thrust REQUIRED CONFIG)
    thrust_create_target(Thrust HOST CPP DEVICE CUDA)
elseif(CMAKE_HIP_ARCHITECTURES)
+9 −1
Original line number Diff line number Diff line
@@ -8,6 +8,14 @@
#define MPITEST_VERSION_MAJOR @mpitest_VERSION_MAJOR@
#define MPITEST_VERSION_MINOR @mpitest_VERSION_MINOR@

int zero();
// Portability macros for annotating functions that may be compiled for
// the host, the device, or both.
//
// When ENABLE_GPU is defined, LINK_HOST / LINK_DEVICE expand to the
// __host__ / __device__ qualifiers; otherwise they expand to nothing so
// the same source also builds without those qualifiers.
#ifdef ENABLE_GPU
#  define LINK_HOST __host__
#  define LINK_DEVICE __device__
#else
#  define LINK_HOST
#  define LINK_DEVICE
#endif
// Convenience macro: both qualifiers together (empty when ENABLE_GPU is
// not defined).
#define LINK_HOSTDEVICE LINK_HOST LINK_DEVICE

int zero();
#endif
+5 −0
Original line number Diff line number Diff line
# Test "none": builds none.cc into its own executable and registers it
# with CTest; it is not linked against any project library here.
add_executable(none none.cc)
add_test(NAME none COMMAND none)

# Test "saxpy": builds saxpy.cc and links it against the mpiwrap target.
add_executable(saxpy saxpy.cc)
target_link_libraries(saxpy PUBLIC mpiwrap)

# Register the saxpy executable with CTest.
add_test(NAME saxpy COMMAND saxpy)

tests/saxpy.cc

0 → 100644
+38 −0
Original line number Diff line number Diff line
#include <iostream>

#include <thrust/device_vector.h>
#include <thrust/transform.h>
#include <thrust/sequence.h>
#include <thrust/copy.h>
#include <thrust/fill.h>
#include <thrust/replace.h>
#include <thrust/functional.h>

#include "config.hh"

/**
 * Binary function object computing `a * x + y` for a scalar `a` that is
 * fixed when the functor is constructed.
 *
 * The call operator is annotated with LINK_HOSTDEVICE (from config.hh).
 */
struct saxpy_functor {
    /// Scalar multiplier; immutable after construction.
    const float a;

    /// Store the multiplier. `explicit` so a bare float is never silently
    /// converted into a functor.
    explicit saxpy_functor(float _a) : a(_a) {}

    /// @return a * x + y
    LINK_HOSTDEVICE
    float operator()(const float& x, const float& y) const {
        return a * x + y;
    }
};

void saxpy_fast(float A, thrust::device_vector<float>& X,
                         thrust::device_vector<float>& Y) {
    // Y <- A * X + Y
    thrust::transform(X.begin(), X.end(),
                      Y.begin(), Y.begin(), saxpy_functor(A));
}

int main(int argc, char *argv[]) {
    const int N = 2048;
    thrust::device_vector<float> X(N);
    thrust::device_vector<float> Y(N);

    saxpy_fast(3.14159*2.0, X, Y);
    return 0;
}