Commit a7100c84 authored by Kenneth Moreland's avatar Kenneth Moreland
Browse files

Do not assume CUDA reduce operator is unary

The `Reduce` algorithm is sometimes used to convert an input type to a
different output type. For example, you can compute the min and max at
the same time by making the output of the binary functor a pair of the
input type. However, for this to work with the CUDA algorithm, you have
to be able to also convert the input type to the output type. This was
previously done by treating the binary operator as also a unary
operator. That's fine for custom operators, but if you are using
something like `thrust::plus`, it has no unary operation. (Why would

So, detect whether the operator has a unary operation. If it does, use
it to cast from the input portal to the output type. If it does not,
just use `static_cast`. Thus, the operator only has to have the unary
operation if `static_cast` does not work.
parent f3a6931f
......@@ -178,8 +178,25 @@ __global__ void SumExclusiveScan(T a, T b, T result, BinaryOperationType binary_
#pragma GCC diagnostic pop
template <typename FunctorType, typename ArgType>
struct FunctorSupportsUnaryImpl
template <typename F, typename A, typename = decltype(std::declval<F>()(std::declval<A>()))>
static std::true_type has(int);
template <typename F, typename A>
static std::false_type has(...);
using type = decltype(has<FunctorType, ArgType>(0));
template <typename FunctorType, typename ArgType>
using FunctorSupportsUnary = typename FunctorSupportsUnaryImpl<FunctorType, ArgType>::type;
template <typename PortalType,
typename BinaryAndUnaryFunctor,
typename = FunctorSupportsUnary<BinaryAndUnaryFunctor, typename PortalType::ValueType>>
struct CastPortal;
template <typename PortalType, typename BinaryAndUnaryFunctor>
struct CastPortal
struct CastPortal<PortalType, BinaryAndUnaryFunctor, std::true_type>
using InputType = typename PortalType::ValueType;
using ValueType = decltype(std::declval<BinaryAndUnaryFunctor>()(std::declval<InputType>()));
......@@ -201,6 +218,28 @@ struct CastPortal
ValueType Get(vtkm::Id index) const { return this->Functor(this->Portal.Get(index)); }
template <typename PortalType, typename BinaryFunctor>
struct CastPortal<PortalType, BinaryFunctor, std::false_type>
using InputType = typename PortalType::ValueType;
using ValueType =
decltype(std::declval<BinaryFunctor>()(std::declval<InputType>(), std::declval<InputType>()));
PortalType Portal;
CastPortal(const PortalType& portal, const BinaryFunctor&)
: Portal(portal)
vtkm::Id GetNumberOfValues() const { return this->Portal.GetNumberOfValues(); }
ValueType Get(vtkm::Id index) const { return static_cast<ValueType>(this->Portal.Get(index)); }
struct CudaFreeFunctor
void operator()(void* ptr) const { VTKM_CUDA_CALL(cudaFree(ptr)); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment