Skip to content
Snippets Groups Projects

WIP: Index vectorization

Closed Slattery, Stuart requested to merge (removed):index_opt into master
+ 51
21
@@ -270,9 +270,6 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
// SoA type.
using soa_type = SoA<inner_array_layout,Types...>;
// View type.
using soa_view_type = Kokkos::View<soa_type*,typename traits::memory_space>;
// Member data types.
using member_types = MemberDataTypes<Types...>;
@@ -338,7 +335,8 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
: _size( 0 )
, _capacity( 0 )
, _num_soa( 0 )
, _data( "soa_view", 0 )
, _managed_data( nullptr )
, _data( nullptr )
{
storeRanksAndExtents(
std::integral_constant<std::size_t,number_of_members-1>() );
@@ -353,7 +351,8 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
: _size( n )
, _capacity( 0 )
, _num_soa( 0 )
, _data( "soa_view", 0 )
, _managed_data( nullptr )
, _data( nullptr )
{
resize( _size );
storeRanksAndExtents(
@@ -438,8 +437,33 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
if ( 0 < n % array_size ) ++num_soa_alloc;
_capacity = num_soa_alloc * array_size;
// Resize the view.
Kokkos::resize( _data, num_soa_alloc );
// Allocate a new block of memory.
std::shared_ptr<soa_type> sp(
(soa_type*) Kokkos::kokkos_malloc<typename traits::memory_space>(
num_soa_alloc * sizeof(soa_type)),
Kokkos::kokkos_free<typename traits::memory_space> );
// Fence before continuing to ensure the allocation is completed.
Kokkos::fence();
// If we have already allocated memory, copy the old memory into the
// new memory. Fence when we are done to ensure copy is complete
// before continuing.
if ( _managed_data != nullptr )
{
Kokkos::Impl::DeepCopy<
typename traits::memory_space,
typename traits::memory_space,
typename traits::execution_space>(
sp.get(), _managed_data.get(), _num_soa * sizeof(soa_type) );
Kokkos::fence();
}
// Swap blocks. The old block will be destroyed when this function exits.
std::swap( _managed_data, sp );
// Assign the raw data pointer.
_data = _managed_data.get();
// Get new pointers and strides for the members.
storePointersAndStrides(
@@ -548,7 +572,7 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
get( const int particle_index ) const
{
return accessStructMember<M>(
_data(Impl::Index<array_size>::s(particle_index)),
_data[Impl::Index<array_size>::s(particle_index)],
Impl::Index<array_size>::i(particle_index) );
}
@@ -561,7 +585,7 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
const int d0 ) const
{
return accessStructMember<M>(
_data(Impl::Index<array_size>::s(particle_index)),
_data[Impl::Index<array_size>::s(particle_index)],
Impl::Index<array_size>::i(particle_index),
d0 );
}
@@ -576,7 +600,7 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
const int d1 ) const
{
return accessStructMember<M>(
_data(Impl::Index<array_size>::s(particle_index)),
_data[Impl::Index<array_size>::s(particle_index)],
Impl::Index<array_size>::i(particle_index),
d0, d1 );
}
@@ -592,7 +616,7 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
const int d2 ) const
{
return accessStructMember<M>(
_data(Impl::Index<array_size>::s(particle_index)),
_data[Impl::Index<array_size>::s(particle_index)],
Impl::Index<array_size>::i(particle_index),
d0, d1, d2 );
}
@@ -609,7 +633,7 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
const int d3 ) const
{
return accessStructMember<M>(
_data(Impl::Index<array_size>::s(particle_index)),
_data[Impl::Index<array_size>::s(particle_index)],
Impl::Index<array_size>::i(particle_index),
d0, d1, d2, d3 );
}
@@ -624,7 +648,7 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
struct_member_reference_type<M> >::type
access( const int struct_index, const int array_index ) const
{
return accessStructMember<M>( _data(struct_index), array_index );
return accessStructMember<M>( _data[struct_index], array_index );
}
// Rank 1
@@ -635,7 +659,7 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
access( const int struct_index, const int array_index,
const int d0 ) const
{
return accessStructMember<M>( _data(struct_index), array_index, d0 );
return accessStructMember<M>( _data[struct_index], array_index, d0 );
}
// Rank 2
@@ -647,7 +671,7 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
const int d0,
const int d1 ) const
{
return accessStructMember<M>( _data(struct_index), array_index, d0, d1 );
return accessStructMember<M>( _data[struct_index], array_index, d0, d1 );
}
// Rank 3
@@ -660,7 +684,7 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
const int d1,
const int d2 ) const
{
return accessStructMember<M>( _data(struct_index), array_index, d0, d1, d2 );
return accessStructMember<M>( _data[struct_index], array_index, d0, d1, d2 );
}
// Rank 4
@@ -674,7 +698,7 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
const int d2,
const int d3 ) const
{
return accessStructMember<M>( _data(struct_index), array_index, d0, d1, d2, d3 );
return accessStructMember<M>( _data[struct_index], array_index, d0, d1, d2, d3 );
}
// -------------------------------
@@ -719,7 +743,7 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
*/
void* ptr() const
{
return _data.data();
return _managed_data.get();
}
private:
@@ -731,7 +755,7 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
static_assert( 0 <= N && N < number_of_members,
"Static loop out of bounds!" );
_pointers[N] =
static_cast<void*>( getStructMember<N>(_data(0)) );
static_cast<void*>( getStructMember<N>(_data[0]) );
static_assert( 0 ==
sizeof(soa_type) % sizeof(struct_member_value_type<N>),
"Stride cannot be calculated for misaligned memory!" );
@@ -990,8 +1014,14 @@ class AoSoA<MemberDataTypes<Types...>,Properties...>
// Number of structs-of-arrays in the array.
int _num_soa;
// Kokkos view of SoAs
soa_view_type _data;
// Structs-of-Arrays managed data. This shared pointer manages the block
// of memory owned by this class such that the copy constructor and
// assignment operator for this class perform a shallow and reference
// counted copy of the data.
std::shared_ptr<soa_type> _managed_data;
// Raw pointer to the managed data.
soa_type* _data;
// Pointers to the first element of each member.
void* _pointers[number_of_members];
Loading