Commit 91d13bdf authored by Kenneth Moreland's avatar Kenneth Moreland Committed by Kitware Robot
Browse files

Merge topic 'reduce-initial-type'

563e23aa Fix unintended cast in TBB Reduce's return value
a7100c84 Do not assume CUDA reduce operator is unary
f3a6931f Fix casting issues in TBB functors
cc5b9a01 Add casts to FunctorsGeneral.h
d9c988b2

 Allow for different types in basic type operators
Acked-by: default avatarKitware Robot <kwrobot@kitware.com>
Merge-request: !2431
parents eb682c79 563e23aa
......@@ -221,37 +221,81 @@ struct NullType
#endif // gcc || clang
struct Add
{
template <typename T, typename U>
inline VTKM_EXEC_CONT auto operator()(const T& a, const U& b) const -> decltype(a + b)
{
return a + b;
}
// If both arguments are short integers, explicitly cast the result back to the
// type to avoid narrowing conversion warnings from operations that promote to
// integers.
template <typename T>
inline VTKM_EXEC_CONT T operator()(const T& a, const T& b) const
inline VTKM_EXEC_CONT
typename std::enable_if<std::is_integral<T>::value && sizeof(T) < sizeof(int), T>::type
operator()(T a, T b) const
{
return T(a + b);
return static_cast<T>(a + b);
}
};
struct Subtract
{
template <typename T, typename U>
inline VTKM_EXEC_CONT auto operator()(const T& a, const U& b) const -> decltype(a - b)
{
return a - b;
}
// If both arguments are short integers, explicitly cast the result back to the
// type to avoid narrowing conversion warnings from operations that promote to
// integers.
template <typename T>
inline VTKM_EXEC_CONT T operator()(const T& a, const T& b) const
inline VTKM_EXEC_CONT
typename std::enable_if<std::is_integral<T>::value && sizeof(T) < sizeof(int), T>::type
operator()(T a, T b) const
{
return T(a - b);
return static_cast<T>(a - b);
}
};
struct Multiply
{
template <typename T, typename U>
inline VTKM_EXEC_CONT auto operator()(const T& a, const U& b) const -> decltype(a * b)
{
return a * b;
}
// If both arguments are short integers, explicitly cast the result back to the
// type to avoid narrowing conversion warnings from operations that promote to
// integers.
template <typename T>
inline VTKM_EXEC_CONT T operator()(const T& a, const T& b) const
inline VTKM_EXEC_CONT
typename std::enable_if<std::is_integral<T>::value && sizeof(T) < sizeof(int), T>::type
operator()(T a, T b) const
{
return T(a * b);
return static_cast<T>(a * b);
}
};
struct Divide
{
template <typename T, typename U>
inline VTKM_EXEC_CONT auto operator()(const T& a, const U& b) const -> decltype(a / b)
{
return a / b;
}
// If both arguments are short integers, explicitly cast the result back to the
// type to avoid narrowing conversion warnings from operations that promote to
// integers.
template <typename T>
inline VTKM_EXEC_CONT T operator()(const T& a, const T& b) const
inline VTKM_EXEC_CONT
typename std::enable_if<std::is_integral<T>::value && sizeof(T) < sizeof(int), T>::type
operator()(T a, T b) const
{
return T(a / b);
return static_cast<T>(a / b);
}
};
......
......@@ -178,8 +178,25 @@ __global__ void SumExclusiveScan(T a, T b, T result, BinaryOperationType binary_
#pragma GCC diagnostic pop
#endif
template <typename FunctorType, typename ArgType>
struct FunctorSupportsUnaryImpl
{
template <typename F, typename A, typename = decltype(std::declval<F>()(std::declval<A>()))>
static std::true_type has(int);
template <typename F, typename A>
static std::false_type has(...);
using type = decltype(has<FunctorType, ArgType>(0));
};
template <typename FunctorType, typename ArgType>
using FunctorSupportsUnary = typename FunctorSupportsUnaryImpl<FunctorType, ArgType>::type;
template <typename PortalType,
typename BinaryAndUnaryFunctor,
typename = FunctorSupportsUnary<BinaryAndUnaryFunctor, typename PortalType::ValueType>>
struct CastPortal;
template <typename PortalType, typename BinaryAndUnaryFunctor>
struct CastPortal
struct CastPortal<PortalType, BinaryAndUnaryFunctor, std::true_type>
{
using InputType = typename PortalType::ValueType;
using ValueType = decltype(std::declval<BinaryAndUnaryFunctor>()(std::declval<InputType>()));
......@@ -201,6 +218,28 @@ struct CastPortal
ValueType Get(vtkm::Id index) const { return this->Functor(this->Portal.Get(index)); }
};
template <typename PortalType, typename BinaryFunctor>
struct CastPortal<PortalType, BinaryFunctor, std::false_type>
{
using InputType = typename PortalType::ValueType;
using ValueType =
decltype(std::declval<BinaryFunctor>()(std::declval<InputType>(), std::declval<InputType>()));
PortalType Portal;
VTKM_CONT
CastPortal(const PortalType& portal, const BinaryFunctor&)
: Portal(portal)
{
}
VTKM_EXEC
vtkm::Id GetNumberOfValues() const { return this->Portal.GetNumberOfValues(); }
VTKM_EXEC
ValueType Get(vtkm::Id index) const { return static_cast<ValueType>(this->Portal.Get(index)); }
};
struct CudaFreeFunctor
{
void operator()(void* ptr) const { VTKM_CUDA_CALL(cudaFree(ptr)); }
......
......@@ -49,7 +49,7 @@ struct WrappedBinaryOperator
template <typename Argument1, typename Argument2>
VTKM_EXEC_CONT ResultType operator()(const Argument1& x, const Argument2& y) const
{
return m_f(x, y);
return static_cast<ResultType>(m_f(x, y));
}
VTKM_SUPPRESS_EXEC_WARNINGS
......@@ -60,7 +60,7 @@ struct WrappedBinaryOperator
{
using ValueTypeX = typename vtkm::internal::ArrayPortalValueReference<Argument1>::ValueType;
using ValueTypeY = typename vtkm::internal::ArrayPortalValueReference<Argument2>::ValueType;
return m_f((ValueTypeX)x, (ValueTypeY)y);
return static_cast<ResultType>(m_f((ValueTypeX)x, (ValueTypeY)y));
}
VTKM_SUPPRESS_EXEC_WARNINGS
......@@ -70,7 +70,7 @@ struct WrappedBinaryOperator
const vtkm::internal::ArrayPortalValueReference<Argument2>& y) const
{
using ValueTypeY = typename vtkm::internal::ArrayPortalValueReference<Argument2>::ValueType;
return m_f(x, (ValueTypeY)y);
return static_cast<ResultType>(m_f(x, (ValueTypeY)y));
}
VTKM_SUPPRESS_EXEC_WARNINGS
......@@ -80,7 +80,7 @@ struct WrappedBinaryOperator
const Argument2& y) const
{
using ValueTypeX = typename vtkm::internal::ArrayPortalValueReference<Argument1>::ValueType;
return m_f((ValueTypeX)x, y);
return static_cast<ResultType>(m_f((ValueTypeX)x, y));
}
};
......@@ -147,11 +147,11 @@ struct ReduceKernel : vtkm::exec::FunctorBase
{
//This will only occur for a single index value, so this is the case
//that needs to handle the initialValue
T partialSum = BinaryOperator(this->InitialValue, this->Portal.Get(offset));
T partialSum = static_cast<T>(BinaryOperator(this->InitialValue, this->Portal.Get(offset)));
vtkm::Id currentIndex = offset + 1;
while (currentIndex < this->PortalLength)
{
partialSum = BinaryOperator(partialSum, this->Portal.Get(currentIndex));
partialSum = static_cast<T>(BinaryOperator(partialSum, this->Portal.Get(currentIndex)));
++currentIndex;
}
return partialSum;
......@@ -160,10 +160,11 @@ struct ReduceKernel : vtkm::exec::FunctorBase
{
//optimize the usecase where all values are valid and we don't
//need to check that we might go out of bounds
T partialSum = BinaryOperator(this->Portal.Get(offset), this->Portal.Get(offset + 1));
T partialSum =
static_cast<T>(BinaryOperator(this->Portal.Get(offset), this->Portal.Get(offset + 1)));
for (int i = 2; i < reduceWidth; ++i)
{
partialSum = BinaryOperator(partialSum, this->Portal.Get(offset + i));
partialSum = static_cast<T>(BinaryOperator(partialSum, this->Portal.Get(offset + i)));
}
return partialSum;
}
......
......@@ -145,7 +145,8 @@ public:
}
template <typename T, typename U, class CIn>
VTKM_CONT static U Reduce(const vtkm::cont::ArrayHandle<T, CIn>& input, U initialValue)
VTKM_CONT static auto Reduce(const vtkm::cont::ArrayHandle<T, CIn>& input, U initialValue)
-> decltype(Reduce(input, initialValue, vtkm::Add{}))
{
VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf);
......@@ -153,9 +154,10 @@ public:
}
template <typename T, typename U, class CIn, class BinaryFunctor>
VTKM_CONT static U Reduce(const vtkm::cont::ArrayHandle<T, CIn>& input,
U initialValue,
BinaryFunctor binary_functor)
VTKM_CONT static auto Reduce(const vtkm::cont::ArrayHandle<T, CIn>& input,
U initialValue,
BinaryFunctor binary_functor)
-> decltype(tbb::ReducePortals(input.ReadPortal(), initialValue, binary_functor))
{
VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf);
......
......@@ -460,14 +460,17 @@ struct ReduceBody
template <class InputPortalType, typename T, class BinaryOperationType>
VTKM_CONT static T ReducePortals(InputPortalType inputPortal,
T initialValue,
BinaryOperationType binaryOperation)
VTKM_CONT static auto ReducePortals(InputPortalType inputPortal,
T initialValue,
BinaryOperationType binaryOperation)
-> decltype(binaryOperation(initialValue, inputPortal.Get(0)))
{
using WrappedBinaryOp = internal::WrappedBinaryOperator<T, BinaryOperationType>;
using ResultType = decltype(binaryOperation(initialValue, inputPortal.Get(0)));
using WrappedBinaryOp = internal::WrappedBinaryOperator<ResultType, BinaryOperationType>;
WrappedBinaryOp wrappedBinaryOp(binaryOperation);
ReduceBody<InputPortalType, T, WrappedBinaryOp> body(inputPortal, initialValue, wrappedBinaryOp);
ReduceBody<InputPortalType, ResultType, WrappedBinaryOp> body(
inputPortal, initialValue, wrappedBinaryOp);
vtkm::Id arrayLength = inputPortal.GetNumberOfValues();
if (arrayLength > 1)
......@@ -484,7 +487,7 @@ VTKM_CONT static T ReducePortals(InputPortalType inputPortal,
else // arrayLength == 0
{
// ReduceBody does not work with an array of size 0.
return initialValue;
return static_cast<ResultType>(initialValue);
}
}
......
......@@ -1332,17 +1332,17 @@ private:
//the output of reduce and scan inclusive should be the same
std::cout << " Reduce with initial value of 0." << std::endl;
vtkm::Id reduce_sum = Algorithm::Reduce(array, vtkm::Id(0));
vtkm::Id reduce_sum = Algorithm::Reduce(array, 0);
std::cout << " Reduce with initial value." << std::endl;
vtkm::Id reduce_sum_with_intial_value = Algorithm::Reduce(array, vtkm::Id(ARRAY_SIZE));
std::cout << " Inclusive scan to check" << std::endl;
vtkm::Id inclusive_sum = Algorithm::ScanInclusive(array, array);
std::cout << " Reduce with 1 value." << std::endl;
array.Allocate(1, vtkm::CopyFlag::On);
vtkm::Id reduce_sum_one_value = Algorithm::Reduce(array, vtkm::Id(0));
vtkm::Id reduce_sum_one_value = Algorithm::Reduce(array, 0);
std::cout << " Reduce with 0 values." << std::endl;
array.Allocate(0);
vtkm::Id reduce_sum_no_values = Algorithm::Reduce(array, vtkm::Id(0));
vtkm::Id reduce_sum_no_values = Algorithm::Reduce(array, 0);
VTKM_TEST_ASSERT(reduce_sum == OFFSET * ARRAY_SIZE, "Got bad sum from Reduce");
VTKM_TEST_ASSERT(reduce_sum_with_intial_value == reduce_sum + ARRAY_SIZE,
"Got bad sum from Reduce with initial value");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment