Add the GPU rigid body pipeline from https://github.com/erwincoumans/experiments as a Bullet 3.x preview for Bullet 2.80
This commit is contained in:
@@ -0,0 +1,53 @@
|
||||
/*
|
||||
2011 Takahiro Harada
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <Adl/Adl.h>
|
||||
#include <AdlPrimitives/Math/Math.h>
|
||||
#include <AdlPrimitives/Sort/SortData.h>
|
||||
#include <AdlPrimitives/Scan/PrefixScan.h>
|
||||
|
||||
namespace adl
|
||||
{
|
||||
|
||||
class RadixSortBase
|
||||
{
|
||||
public:
|
||||
enum Option
|
||||
{
|
||||
SORT_SIMPLE,
|
||||
SORT_STANDARD,
|
||||
SORT_ADVANCED
|
||||
};
|
||||
};
|
||||
|
||||
template<DeviceType TYPE>
|
||||
class RadixSort : public RadixSortBase
|
||||
{
|
||||
public:
|
||||
struct Data
|
||||
{
|
||||
Option m_option;
|
||||
const Device* m_deviceData;
|
||||
typename PrefixScan<TYPE>::Data* m_scanData;
|
||||
int m_maxSize;
|
||||
};
|
||||
|
||||
|
||||
static
|
||||
Data* allocate(const Device* deviceData, int maxSize, Option option = SORT_STANDARD);
|
||||
|
||||
static
|
||||
void deallocate(Data* data);
|
||||
|
||||
static
|
||||
void execute(Data* data, Buffer<SortData>& inout, int n, int sortBits = 32);
|
||||
};
|
||||
|
||||
|
||||
#include <AdlPrimitives/Sort/RadixSort.inl>
|
||||
#include <AdlPrimitives/Sort/RadixSortHost.inl>
|
||||
|
||||
};
|
||||
@@ -0,0 +1,58 @@
|
||||
/*
|
||||
2011 Takahiro Harada
|
||||
*/
|
||||
|
||||
#include <AdlPrimitives/Sort/RadixSortSimple.inl>
|
||||
#include <AdlPrimitives/Sort/RadixSortStandard.inl>
|
||||
#include <AdlPrimitives/Sort/RadixSortAdvanced.inl>
|
||||
|
||||
|
||||
#define DISPATCH_IMPL(x) \
|
||||
switch( data->m_option ) \
|
||||
{ \
|
||||
case SORT_SIMPLE: RadixSortSimple<TYPE>::x; break; \
|
||||
case SORT_STANDARD: RadixSortStandard<TYPE>::x; break; \
|
||||
case SORT_ADVANCED: RadixSortAdvanced<TYPE>::x; break; \
|
||||
default:ADLASSERT(0);break; \
|
||||
}
|
||||
|
||||
template<DeviceType TYPE>
|
||||
typename RadixSort<TYPE>::Data* RadixSort<TYPE>::allocate(const Device* deviceData, int maxSize, Option option)
|
||||
{
|
||||
ADLASSERT( TYPE == deviceData->m_type );
|
||||
|
||||
void* dataOut;
|
||||
switch( option )
|
||||
{
|
||||
case SORT_SIMPLE:
|
||||
dataOut = RadixSortSimple<TYPE>::allocate( deviceData, maxSize, option );
|
||||
break;
|
||||
case SORT_STANDARD:
|
||||
dataOut = RadixSortStandard<TYPE>::allocate( deviceData, maxSize, option );
|
||||
break;
|
||||
case SORT_ADVANCED:
|
||||
dataOut = RadixSortAdvanced<TYPE>::allocate( deviceData, maxSize, option );
|
||||
break;
|
||||
default:
|
||||
ADLASSERT(0);
|
||||
break;
|
||||
}
|
||||
return (typename RadixSort<TYPE>::Data*)dataOut;
|
||||
}
|
||||
|
||||
template<DeviceType TYPE>
|
||||
void RadixSort<TYPE>::deallocate(Data* data)
|
||||
{
|
||||
DISPATCH_IMPL( deallocate( data ) );
|
||||
}
|
||||
|
||||
template<DeviceType TYPE>
|
||||
void RadixSort<TYPE>::execute(Data* data, Buffer<SortData>& inout, int n, int sortBits)
|
||||
{
|
||||
DISPATCH_IMPL( execute( data, inout, n, sortBits ) );
|
||||
}
|
||||
|
||||
|
||||
#undef DISPATCH_IMPL
|
||||
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
/*
|
||||
2011 Takahiro Harada
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <Adl/Adl.h>
|
||||
#include <AdlPrimitives/Math/Math.h>
|
||||
#include <AdlPrimitives/Copy/Copy.h>
|
||||
#include <AdlPrimitives/Sort/SortData.h>
|
||||
|
||||
namespace adl
|
||||
{
|
||||
|
||||
class RadixSort32Base
|
||||
{
|
||||
public:
|
||||
// enum Option
|
||||
// {
|
||||
// SORT_SIMPLE,
|
||||
// SORT_STANDARD,
|
||||
// SORT_ADVANCED
|
||||
// };
|
||||
};
|
||||
|
||||
template<DeviceType TYPE>
|
||||
class RadixSort32 : public RadixSort32Base
|
||||
{
|
||||
public:
|
||||
typedef Launcher::BufferInfo BufferInfo;
|
||||
|
||||
enum
|
||||
{
|
||||
DATA_ALIGNMENT = 256,
|
||||
WG_SIZE = 64,
|
||||
ELEMENTS_PER_WORK_ITEM = (256/WG_SIZE),
|
||||
BITS_PER_PASS = 4,
|
||||
|
||||
// if you change this, change nPerWI in kernel as well
|
||||
NUM_WGS = 20*6, // cypress
|
||||
// NUM_WGS = 24*6, // cayman
|
||||
// NUM_WGS = 32*4, // nv
|
||||
};
|
||||
|
||||
struct ConstData
|
||||
{
|
||||
int m_n;
|
||||
int m_nWGs;
|
||||
int m_startBit;
|
||||
int m_nBlocksPerWG;
|
||||
};
|
||||
|
||||
struct Data
|
||||
{
|
||||
const Device* m_device;
|
||||
int m_maxSize;
|
||||
|
||||
Kernel* m_streamCountKernel;
|
||||
Kernel* m_streamCountSortDataKernel;
|
||||
Kernel* m_prefixScanKernel;
|
||||
Kernel* m_sortAndScatterKernel;
|
||||
Kernel* m_sortAndScatterKeyValueKernel;
|
||||
Kernel* m_sortAndScatterSortDataKernel;
|
||||
|
||||
Buffer<u32>* m_workBuffer0;
|
||||
Buffer<u32>* m_workBuffer1;
|
||||
Buffer<u32>* m_workBuffer2;
|
||||
Buffer<SortData>* m_workBuffer3;
|
||||
|
||||
Buffer<ConstData>* m_constBuffer[32/BITS_PER_PASS];
|
||||
|
||||
typename Copy<TYPE>::Data* m_copyData;
|
||||
};
|
||||
|
||||
static
|
||||
Data* allocate(const Device* device, int maxSize);
|
||||
|
||||
static
|
||||
void deallocate(Data* data);
|
||||
|
||||
static
|
||||
void execute(Data* data, Buffer<u32>& inout, int n, int sortBits = 32);
|
||||
|
||||
static
|
||||
void execute(Data* data, Buffer<u32>& in, Buffer<u32>& out, int n, int sortBits = 32);
|
||||
|
||||
static
|
||||
void execute(Data* data, Buffer<u32>& keysIn, Buffer<u32>& keysOut, Buffer<u32>& valuesIn, Buffer<u32>& valuesOut, int n, int sortBits = 32);
|
||||
|
||||
static
|
||||
void execute(Data* data, Buffer<SortData>& keyValuesInOut, int n, int sortBits = 32 );
|
||||
};
|
||||
|
||||
|
||||
#include <AdlPrimitives/Sort/RadixSort32Host.inl>
|
||||
#include <AdlPrimitives/Sort/RadixSort32.inl>
|
||||
|
||||
};
|
||||
@@ -0,0 +1,346 @@
|
||||
/*
|
||||
2011 Takahiro Harada
|
||||
*/
|
||||
|
||||
#define PATH "..\\..\\opencl\\primitives\\AdlPrimitives\\Sort\\RadixSort32Kernels"
|
||||
#define RADIXSORT32_KERNEL0 "StreamCountKernel"
|
||||
#define RADIXSORT32_KERNEL1 "PrefixScanKernel"
|
||||
#define RADIXSORT32_KERNEL2 "SortAndScatterKernel"
|
||||
#define RADIXSORT32_KERNEL3 "SortAndScatterKeyValueKernel"
|
||||
#define RADIXSORT32_KERNEL4 "SortAndScatterSortDataKernel"
|
||||
#define RADIXSORT32_KERNEL5 "StreamCountSortDataKernel"
|
||||
|
||||
#include "RadixSort32KernelsCL.h"
|
||||
#include "RadixSort32KernelsDX11.h"
|
||||
|
||||
// todo. Shader compiler (2010JuneSDK) doesn't allow me to place Barriers in SortAndScatterKernel...
|
||||
// So it only works on a GPU with 64 wide SIMD.
|
||||
|
||||
template<DeviceType TYPE>
|
||||
typename RadixSort32<TYPE>::Data* RadixSort32<TYPE>::allocate( const Device* device, int maxSize )
|
||||
{
|
||||
ADLASSERT( TYPE == device->m_type );
|
||||
|
||||
const char* src[] =
|
||||
#if defined(ADL_LOAD_KERNEL_FROM_STRING)
|
||||
{radixSort32KernelsCL, radixSort32KernelsDX11};
|
||||
#else
|
||||
{0,0};
|
||||
#endif
|
||||
|
||||
Data* data = new Data;
|
||||
data->m_device = device;
|
||||
data->m_maxSize = maxSize;
|
||||
data->m_streamCountKernel = device->getKernel( PATH, RADIXSORT32_KERNEL0, 0, src[TYPE] );
|
||||
data->m_streamCountSortDataKernel = device->getKernel( PATH, RADIXSORT32_KERNEL5, 0, src[TYPE] );
|
||||
|
||||
|
||||
|
||||
data->m_prefixScanKernel = device->getKernel( PATH, RADIXSORT32_KERNEL1, 0, src[TYPE] );
|
||||
data->m_sortAndScatterKernel = device->getKernel( PATH, RADIXSORT32_KERNEL2, 0, src[TYPE] );
|
||||
data->m_sortAndScatterKeyValueKernel = device->getKernel( PATH, RADIXSORT32_KERNEL3, 0, src[TYPE] );
|
||||
data->m_sortAndScatterSortDataKernel = device->getKernel( PATH, RADIXSORT32_KERNEL4, 0, src[TYPE] );
|
||||
|
||||
int wtf = NUM_WGS*(1<<BITS_PER_PASS);
|
||||
|
||||
data->m_workBuffer0 = new Buffer<u32>( device, maxSize );
|
||||
data->m_workBuffer1 = new Buffer<u32>( device , wtf );
|
||||
data->m_workBuffer2 = new Buffer<u32>( device, maxSize );
|
||||
data->m_workBuffer3 = new Buffer<SortData>(device,maxSize);
|
||||
|
||||
|
||||
for(int i=0; i<32/BITS_PER_PASS; i++)
|
||||
data->m_constBuffer[i] = new Buffer<ConstData>( device, 1, BufferBase::BUFFER_CONST );
|
||||
|
||||
data->m_copyData = Copy<TYPE>::allocate( device );
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
template<DeviceType TYPE>
|
||||
void RadixSort32<TYPE>::deallocate( Data* data )
|
||||
{
|
||||
delete data->m_workBuffer0;
|
||||
delete data->m_workBuffer1;
|
||||
delete data->m_workBuffer2;
|
||||
delete data->m_workBuffer3;
|
||||
|
||||
for(int i=0; i<32/BITS_PER_PASS; i++)
|
||||
delete data->m_constBuffer[i];
|
||||
|
||||
Copy<TYPE>::deallocate( data->m_copyData );
|
||||
|
||||
delete data;
|
||||
}
|
||||
|
||||
template<DeviceType TYPE>
|
||||
void RadixSort32<TYPE>::execute(Data* data, Buffer<u32>& inout, int n, int sortBits /* = 32 */ )
|
||||
{
|
||||
ADLASSERT( n%DATA_ALIGNMENT == 0 );
|
||||
ADLASSERT( n <= data->m_maxSize );
|
||||
// ADLASSERT( ELEMENTS_PER_WORK_ITEM == 4 );
|
||||
ADLASSERT( BITS_PER_PASS == 4 );
|
||||
ADLASSERT( WG_SIZE == 64 );
|
||||
ADLASSERT( (sortBits&0x3) == 0 );
|
||||
|
||||
Buffer<u32>* src = &inout;
|
||||
Buffer<u32>* dst = data->m_workBuffer0;
|
||||
Buffer<u32>* histogramBuffer = data->m_workBuffer1;
|
||||
|
||||
int nWGs = NUM_WGS;
|
||||
ConstData cdata;
|
||||
{
|
||||
int nBlocks = (n+ELEMENTS_PER_WORK_ITEM*WG_SIZE-1)/(ELEMENTS_PER_WORK_ITEM*WG_SIZE);
|
||||
|
||||
cdata.m_n = n;
|
||||
cdata.m_nWGs = NUM_WGS;
|
||||
cdata.m_startBit = 0;
|
||||
cdata.m_nBlocksPerWG = (nBlocks + cdata.m_nWGs - 1)/cdata.m_nWGs;
|
||||
|
||||
if( nBlocks < NUM_WGS )
|
||||
{
|
||||
cdata.m_nBlocksPerWG = 1;
|
||||
nWGs = nBlocks;
|
||||
}
|
||||
}
|
||||
|
||||
for(int ib=0; ib<sortBits; ib+=4)
|
||||
{
|
||||
cdata.m_startBit = ib;
|
||||
{
|
||||
BufferInfo bInfo[] = { BufferInfo( src, true ), BufferInfo( histogramBuffer ) };
|
||||
Launcher launcher( data->m_device, data->m_streamCountKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[ib/4], cdata );
|
||||
launcher.launch1D( NUM_WGS*WG_SIZE, WG_SIZE );
|
||||
}
|
||||
{// prefix scan group histogram
|
||||
BufferInfo bInfo[] = { BufferInfo( histogramBuffer ) };
|
||||
Launcher launcher( data->m_device, data->m_prefixScanKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[ib/4], cdata );
|
||||
launcher.launch1D( 128, 128 );
|
||||
}
|
||||
{// local sort and distribute
|
||||
BufferInfo bInfo[] = { BufferInfo( src, true ), BufferInfo( histogramBuffer, true ), BufferInfo( dst ) };
|
||||
Launcher launcher( data->m_device, data->m_sortAndScatterKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[ib/4], cdata );
|
||||
launcher.launch1D( nWGs*WG_SIZE, WG_SIZE );
|
||||
}
|
||||
swap2( src, dst );
|
||||
}
|
||||
|
||||
if( src != &inout )
|
||||
{
|
||||
Copy<TYPE>::execute( data->m_copyData, (Buffer<float>&)inout, (Buffer<float>&)*src, n );
|
||||
}
|
||||
}
|
||||
|
||||
template<DeviceType TYPE>
|
||||
void RadixSort32<TYPE>::execute(Data* data, Buffer<u32>& in, Buffer<u32>& out, int n, int sortBits /* = 32 */ )
|
||||
{
|
||||
ADLASSERT( n%DATA_ALIGNMENT == 0 );
|
||||
ADLASSERT( n <= data->m_maxSize );
|
||||
// ADLASSERT( ELEMENTS_PER_WORK_ITEM == 4 );
|
||||
ADLASSERT( BITS_PER_PASS == 4 );
|
||||
ADLASSERT( WG_SIZE == 64 );
|
||||
ADLASSERT( (sortBits&0x3) == 0 );
|
||||
|
||||
Buffer<u32>* src = ∈
|
||||
Buffer<u32>* dst = data->m_workBuffer0;
|
||||
Buffer<u32>* histogramBuffer = data->m_workBuffer1;
|
||||
|
||||
int nWGs = NUM_WGS;
|
||||
ConstData cdata;
|
||||
{
|
||||
int nBlocks = (n+ELEMENTS_PER_WORK_ITEM*WG_SIZE-1)/(ELEMENTS_PER_WORK_ITEM*WG_SIZE);
|
||||
cdata.m_n = n;
|
||||
cdata.m_nWGs = NUM_WGS;
|
||||
cdata.m_startBit = 0;
|
||||
cdata.m_nBlocksPerWG = (nBlocks + cdata.m_nWGs - 1)/cdata.m_nWGs;
|
||||
if( nBlocks < NUM_WGS )
|
||||
{
|
||||
cdata.m_nBlocksPerWG = 1;
|
||||
nWGs = nBlocks;
|
||||
}
|
||||
}
|
||||
|
||||
if( sortBits == 4 ) dst = &out;
|
||||
|
||||
for(int ib=0; ib<sortBits; ib+=4)
|
||||
{
|
||||
if( ib==4 )
|
||||
{
|
||||
dst = &out;
|
||||
}
|
||||
|
||||
cdata.m_startBit = ib;
|
||||
{
|
||||
BufferInfo bInfo[] = { BufferInfo( src, true ), BufferInfo( histogramBuffer ) };
|
||||
Launcher launcher( data->m_device, data->m_streamCountKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[ib/4], cdata );
|
||||
launcher.launch1D( NUM_WGS*WG_SIZE, WG_SIZE );
|
||||
}
|
||||
{// prefix scan group histogram
|
||||
BufferInfo bInfo[] = { BufferInfo( histogramBuffer ) };
|
||||
Launcher launcher( data->m_device, data->m_prefixScanKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[ib/4], cdata );
|
||||
launcher.launch1D( 128, 128 );
|
||||
}
|
||||
{// local sort and distribute
|
||||
BufferInfo bInfo[] = { BufferInfo( src, true ), BufferInfo( histogramBuffer, true ), BufferInfo( dst ) };
|
||||
Launcher launcher( data->m_device, data->m_sortAndScatterKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[ib/4], cdata );
|
||||
launcher.launch1D( nWGs*WG_SIZE, WG_SIZE );
|
||||
}
|
||||
swap2( src, dst );
|
||||
}
|
||||
}
|
||||
|
||||
template<DeviceType TYPE>
|
||||
void RadixSort32<TYPE>::execute(Data* data, Buffer<u32>& keysIn, Buffer<u32>& keysOut, Buffer<u32>& valuesIn, Buffer<u32>& valuesOut, int n, int sortBits /* = 32 */)
|
||||
{
|
||||
ADLASSERT( n%DATA_ALIGNMENT == 0 );
|
||||
ADLASSERT( n <= data->m_maxSize );
|
||||
// ADLASSERT( ELEMENTS_PER_WORK_ITEM == 4 );
|
||||
ADLASSERT( BITS_PER_PASS == 4 );
|
||||
ADLASSERT( WG_SIZE == 64 );
|
||||
ADLASSERT( (sortBits&0x3) == 0 );
|
||||
|
||||
Buffer<u32>* src = &keysIn;
|
||||
Buffer<u32>* srcVal = &valuesIn;
|
||||
Buffer<u32>* dst = data->m_workBuffer0;
|
||||
Buffer<u32>* dstVal = data->m_workBuffer2;
|
||||
Buffer<u32>* histogramBuffer = data->m_workBuffer1;
|
||||
|
||||
int nWGs = NUM_WGS;
|
||||
ConstData cdata;
|
||||
{
|
||||
int nBlocks = (n+ELEMENTS_PER_WORK_ITEM*WG_SIZE-1)/(ELEMENTS_PER_WORK_ITEM*WG_SIZE);
|
||||
cdata.m_n = n;
|
||||
cdata.m_nWGs = NUM_WGS;
|
||||
cdata.m_startBit = 0;
|
||||
cdata.m_nBlocksPerWG = (nBlocks + cdata.m_nWGs - 1)/cdata.m_nWGs;
|
||||
if( nBlocks < NUM_WGS )
|
||||
{
|
||||
cdata.m_nBlocksPerWG = 1;
|
||||
nWGs = nBlocks;
|
||||
}
|
||||
}
|
||||
|
||||
if( sortBits == 4 )
|
||||
{
|
||||
dst = &keysOut;
|
||||
dstVal = &valuesOut;
|
||||
}
|
||||
|
||||
for(int ib=0; ib<sortBits; ib+=4)
|
||||
{
|
||||
if( ib==4 )
|
||||
{
|
||||
dst = &keysOut;
|
||||
dstVal = &valuesOut;
|
||||
}
|
||||
|
||||
cdata.m_startBit = ib;
|
||||
{
|
||||
BufferInfo bInfo[] = { BufferInfo( src, true ), BufferInfo( histogramBuffer ) };
|
||||
Launcher launcher( data->m_device, data->m_streamCountKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[ib/4], cdata );
|
||||
launcher.launch1D( NUM_WGS*WG_SIZE, WG_SIZE );
|
||||
}
|
||||
{// prefix scan group histogram
|
||||
BufferInfo bInfo[] = { BufferInfo( histogramBuffer ) };
|
||||
Launcher launcher( data->m_device, data->m_prefixScanKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[ib/4], cdata );
|
||||
launcher.launch1D( 128, 128 );
|
||||
}
|
||||
{// local sort and distribute
|
||||
BufferInfo bInfo[] = { BufferInfo( src, true ), BufferInfo( srcVal, true ), BufferInfo( histogramBuffer, true ), BufferInfo( dst ), BufferInfo( dstVal ) };
|
||||
Launcher launcher( data->m_device, data->m_sortAndScatterKeyValueKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[ib/4], cdata );
|
||||
launcher.launch1D( nWGs*WG_SIZE, WG_SIZE );
|
||||
}
|
||||
swap2( src, dst );
|
||||
swap2( srcVal, dstVal );
|
||||
}
|
||||
}
|
||||
|
||||
template<DeviceType TYPE>
|
||||
void RadixSort32<TYPE>::execute(Data* data, Buffer<SortData>& keyValuesInOut, int n, int sortBits /* = 32 */)
|
||||
{
|
||||
ADLASSERT( n%DATA_ALIGNMENT == 0 );
|
||||
ADLASSERT( n <= data->m_maxSize );
|
||||
// ADLASSERT( ELEMENTS_PER_WORK_ITEM == 4 );
|
||||
ADLASSERT( BITS_PER_PASS == 4 );
|
||||
ADLASSERT( WG_SIZE == 64 );
|
||||
ADLASSERT( (sortBits&0x3) == 0 );
|
||||
|
||||
Buffer<SortData>* src = &keyValuesInOut;
|
||||
Buffer<SortData>* dst = data->m_workBuffer3;
|
||||
|
||||
Buffer<u32>* histogramBuffer = data->m_workBuffer1;
|
||||
|
||||
int nWGs = NUM_WGS;
|
||||
ConstData cdata;
|
||||
{
|
||||
int nBlocks = (n+ELEMENTS_PER_WORK_ITEM*WG_SIZE-1)/(ELEMENTS_PER_WORK_ITEM*WG_SIZE);
|
||||
cdata.m_n = n;
|
||||
cdata.m_nWGs = NUM_WGS;
|
||||
cdata.m_startBit = 0;
|
||||
cdata.m_nBlocksPerWG = (nBlocks + cdata.m_nWGs - 1)/cdata.m_nWGs;
|
||||
if( nBlocks < NUM_WGS )
|
||||
{
|
||||
cdata.m_nBlocksPerWG = 1;
|
||||
nWGs = nBlocks;
|
||||
}
|
||||
}
|
||||
|
||||
int count=0;
|
||||
for(int ib=0; ib<sortBits; ib+=4)
|
||||
{
|
||||
cdata.m_startBit = ib;
|
||||
{
|
||||
BufferInfo bInfo[] = { BufferInfo( src, true ), BufferInfo( histogramBuffer ) };
|
||||
Launcher launcher( data->m_device, data->m_streamCountSortDataKernel);
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[ib/4], cdata );
|
||||
launcher.launch1D( NUM_WGS*WG_SIZE, WG_SIZE );
|
||||
}
|
||||
{// prefix scan group histogram
|
||||
BufferInfo bInfo[] = { BufferInfo( histogramBuffer ) };
|
||||
Launcher launcher( data->m_device, data->m_prefixScanKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[ib/4], cdata );
|
||||
launcher.launch1D( 128, 128 );
|
||||
}
|
||||
{// local sort and distribute
|
||||
BufferInfo bInfo[] = { BufferInfo( src, true ), BufferInfo( histogramBuffer, true ), BufferInfo( dst )};
|
||||
Launcher launcher( data->m_device, data->m_sortAndScatterSortDataKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[ib/4], cdata );
|
||||
launcher.launch1D( nWGs*WG_SIZE, WG_SIZE );
|
||||
}
|
||||
swap2( src, dst );
|
||||
count++;
|
||||
}
|
||||
|
||||
if (count&1)
|
||||
{
|
||||
ADLASSERT(0);//need to copy from workbuffer to keyValuesInOut
|
||||
|
||||
}
|
||||
}
|
||||
#undef PATH
|
||||
#undef RADIXSORT32_KERNEL0
|
||||
#undef RADIXSORT32_KERNEL1
|
||||
#undef RADIXSORT32_KERNEL2
|
||||
#undef RADIXSORT32_KERNEL3
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
/*
|
||||
2011 Takahiro Harada
|
||||
*/
|
||||
|
||||
template<>
|
||||
class RadixSort32<TYPE_HOST> : public RadixSort32Base
|
||||
{
|
||||
public:
|
||||
typedef Launcher::BufferInfo BufferInfo;
|
||||
|
||||
enum
|
||||
{
|
||||
BITS_PER_PASS = 8,
|
||||
NUM_TABLES = (1<<BITS_PER_PASS),
|
||||
};
|
||||
|
||||
struct Data
|
||||
{
|
||||
HostBuffer<u32>* m_workBuffer;
|
||||
};
|
||||
|
||||
static
|
||||
Data* allocate(const Device* device, int maxSize)
|
||||
{
|
||||
ADLASSERT( device->m_type == TYPE_HOST );
|
||||
|
||||
Data* data = new Data;
|
||||
data->m_workBuffer = new HostBuffer<u32>( device, maxSize );
|
||||
return data;
|
||||
}
|
||||
|
||||
static
|
||||
void deallocate(Data* data)
|
||||
{
|
||||
delete data->m_workBuffer;
|
||||
delete data;
|
||||
}
|
||||
|
||||
static
|
||||
void execute(Data* data, Buffer<u32>& inout, int n, int sortBits = 32)
|
||||
{
|
||||
ADLASSERT( inout.getType() == TYPE_HOST );
|
||||
|
||||
int tables[NUM_TABLES];
|
||||
int counter[NUM_TABLES];
|
||||
|
||||
u32* src = inout.m_ptr;
|
||||
u32* dst = data->m_workBuffer->m_ptr;
|
||||
|
||||
for(int startBit=0; startBit<sortBits; startBit+=BITS_PER_PASS)
|
||||
{
|
||||
for(int i=0; i<NUM_TABLES; i++)
|
||||
{
|
||||
tables[i] = 0;
|
||||
}
|
||||
|
||||
for(int i=0; i<n; i++)
|
||||
{
|
||||
int tableIdx = (src[i] >> startBit) & (NUM_TABLES-1);
|
||||
tables[tableIdx]++;
|
||||
}
|
||||
|
||||
// prefix scan
|
||||
int sum = 0;
|
||||
for(int i=0; i<NUM_TABLES; i++)
|
||||
{
|
||||
int iData = tables[i];
|
||||
tables[i] = sum;
|
||||
sum += iData;
|
||||
counter[i] = 0;
|
||||
}
|
||||
|
||||
// distribute
|
||||
for(int i=0; i<n; i++)
|
||||
{
|
||||
int tableIdx = (src[i] >> startBit) & (NUM_TABLES-1);
|
||||
|
||||
dst[tables[tableIdx] + counter[tableIdx]] = src[i];
|
||||
counter[tableIdx] ++;
|
||||
}
|
||||
|
||||
swap2( src, dst );
|
||||
}
|
||||
|
||||
{
|
||||
if( src != inout.m_ptr )
|
||||
{
|
||||
memcpy( dst, src, sizeof(u32)*n );
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static
|
||||
void execute(Data* data, Buffer<u32>& keyInout, const Buffer<u32>& valueInout, int n, int sortBits = 32)
|
||||
{
|
||||
ADLASSERT( keyInout.getType() == TYPE_HOST );
|
||||
|
||||
int tables[NUM_TABLES];
|
||||
int counter[NUM_TABLES];
|
||||
|
||||
u32* src = keyInout.m_ptr;
|
||||
u32* dst = data->m_workBuffer->m_ptr;
|
||||
|
||||
HostBuffer<u32> bufVal(valueInout.m_device, valueInout.m_size);
|
||||
bufVal.write(valueInout.m_ptr, valueInout.m_size);
|
||||
|
||||
u32* srcVal = valueInout.m_ptr;
|
||||
u32* dstVal = bufVal.m_ptr;
|
||||
|
||||
for(int startBit=0; startBit<sortBits; startBit+=BITS_PER_PASS)
|
||||
{
|
||||
for(int i=0; i<NUM_TABLES; i++)
|
||||
{
|
||||
tables[i] = 0;
|
||||
}
|
||||
|
||||
for(int i=0; i<n; i++)
|
||||
{
|
||||
int tableIdx = (src[i] >> startBit) & (NUM_TABLES-1);
|
||||
tables[tableIdx]++;
|
||||
}
|
||||
|
||||
// prefix scan
|
||||
int sum = 0;
|
||||
for(int i=0; i<NUM_TABLES; i++)
|
||||
{
|
||||
int iData = tables[i];
|
||||
tables[i] = sum;
|
||||
sum += iData;
|
||||
counter[i] = 0;
|
||||
}
|
||||
|
||||
// distribute
|
||||
for(int i=0; i<n; i++)
|
||||
{
|
||||
int tableIdx = (src[i] >> startBit) & (NUM_TABLES-1);
|
||||
int newIdx = tables[tableIdx] + counter[tableIdx];
|
||||
dst[newIdx] = src[i];
|
||||
dstVal[newIdx] = srcVal[i];
|
||||
counter[tableIdx]++;
|
||||
}
|
||||
|
||||
swap2( src, dst );
|
||||
swap2( srcVal, dstVal );
|
||||
}
|
||||
|
||||
{
|
||||
if( src != keyInout.m_ptr )
|
||||
{
|
||||
memcpy( dst, src, sizeof(u32)*n );
|
||||
}
|
||||
|
||||
if( srcVal != valueInout.m_ptr )
|
||||
{
|
||||
memcpy( dstVal, srcVal, sizeof(u32)*n );
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,985 @@
|
||||
/*
|
||||
2011 Takahiro Harada
|
||||
*/
|
||||
|
||||
typedef uint u32;
|
||||
|
||||
#define GET_GROUP_IDX groupIdx.x
|
||||
#define GET_LOCAL_IDX localIdx.x
|
||||
#define GET_GLOBAL_IDX globalIdx.x
|
||||
#define GROUP_LDS_BARRIER GroupMemoryBarrierWithGroupSync()
|
||||
#define DEFAULT_ARGS uint3 globalIdx : SV_DispatchThreadID, uint3 localIdx : SV_GroupThreadID, uint3 groupIdx : SV_GroupID
|
||||
#define AtomInc(x) InterlockedAdd(x, 1)
|
||||
#define AtomInc1(x, out) InterlockedAdd(x, 1, out)
|
||||
|
||||
#define min2 min
|
||||
#define max2 max
|
||||
|
||||
|
||||
cbuffer CB0 : register( b0 )
|
||||
{
|
||||
int m_startBit;
|
||||
int m_totalBlocks;
|
||||
int m_nWorkGroupsToExecute;
|
||||
int m_nBlocksPerGroup;
|
||||
|
||||
};
|
||||
|
||||
|
||||
typedef struct {
|
||||
unsigned int key;
|
||||
unsigned int value;
|
||||
} KeyValuePair;
|
||||
|
||||
|
||||
StructuredBuffer<u32> rHistogram : register(t0);
|
||||
|
||||
RWStructuredBuffer<KeyValuePair> dataToSort : register( u0 );
|
||||
RWStructuredBuffer<KeyValuePair> dataToSortOut : register( u1 );
|
||||
|
||||
|
||||
|
||||
#define WG_SIZE 128
|
||||
#define ELEMENTS_PER_WORK_ITEM 4
|
||||
#define BITS_PER_PASS 4
|
||||
#define NUM_BUCKET (1<<BITS_PER_PASS)
|
||||
|
||||
|
||||
groupshared u32 sorterSharedMemory[max(WG_SIZE*2*2, WG_SIZE*ELEMENTS_PER_WORK_ITEM*2)];
|
||||
groupshared u32 localHistogramToCarry[NUM_BUCKET];
|
||||
groupshared u32 localHistogram[NUM_BUCKET*2];
|
||||
groupshared u32 localHistogramMat[NUM_BUCKET*WG_SIZE];
|
||||
groupshared u32 localPrefixSum[NUM_BUCKET];
|
||||
|
||||
|
||||
|
||||
#define SET_LOCAL_SORT_DATA(idx, sortDataIn) sorterSharedMemory[2*(idx)+0] = sortDataIn.key; sorterSharedMemory[2*(idx)+1] = sortDataIn.value;
|
||||
#define GET_LOCAL_SORT_DATA(idx, sortDataOut) sortDataOut.key = sorterSharedMemory[2*(idx)+0]; sortDataOut.value = sorterSharedMemory[2*(idx)+1];
|
||||
|
||||
|
||||
|
||||
uint4 prefixScanVector( uint4 data )
|
||||
{
|
||||
data.y += data.x;
|
||||
data.w += data.z;
|
||||
data.z += data.y;
|
||||
data.w += data.y;
|
||||
return data;
|
||||
}
|
||||
|
||||
uint prefixScanVectorEx( inout uint4 data )
|
||||
{
|
||||
uint4 backup = data;
|
||||
data.y += data.x;
|
||||
data.w += data.z;
|
||||
data.z += data.y;
|
||||
data.w += data.y;
|
||||
uint sum = data.w;
|
||||
data -= backup;
|
||||
return sum;
|
||||
}
|
||||
|
||||
uint localPrefixScan128( uint pData, uint lIdx, inout uint totalSum )
|
||||
{
|
||||
{ // Set data
|
||||
sorterSharedMemory[lIdx] = 0;
|
||||
sorterSharedMemory[lIdx+WG_SIZE] = pData;
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
{ // Prefix sum
|
||||
int idx = 2*lIdx + (WG_SIZE+1);
|
||||
if( lIdx < 64 )
|
||||
{
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-1];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-2];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-4];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-8];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-16];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-32];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-64];
|
||||
}
|
||||
if( lIdx < 64 ) sorterSharedMemory[idx-1] += sorterSharedMemory[idx-2];
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
totalSum = sorterSharedMemory[WG_SIZE*2-1];
|
||||
return sorterSharedMemory[lIdx+127];
|
||||
}
|
||||
|
||||
void localPrefixScan128Dual( uint pData0, uint pData1, uint lIdx,
|
||||
inout uint rank0, inout uint rank1,
|
||||
inout uint totalSum0, inout uint totalSum1 )
|
||||
{
|
||||
{ // Set data
|
||||
sorterSharedMemory[lIdx] = 0;
|
||||
sorterSharedMemory[lIdx+WG_SIZE] = pData0;
|
||||
sorterSharedMemory[2*WG_SIZE+lIdx] = 0;
|
||||
sorterSharedMemory[2*WG_SIZE+lIdx+WG_SIZE] = pData1;
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
// if( lIdx < 128 ) // todo. assert wg size is 128
|
||||
{ // Prefix sum
|
||||
int blockIdx = lIdx/64;
|
||||
int groupIdx = lIdx%64;
|
||||
int idx = 2*groupIdx + (WG_SIZE+1) + (2*WG_SIZE)*blockIdx;
|
||||
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-1];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-2];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-4];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-8];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-16];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-32];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-64];
|
||||
|
||||
sorterSharedMemory[idx-1] += sorterSharedMemory[idx-2];
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
totalSum0 = sorterSharedMemory[WG_SIZE*2-1];
|
||||
rank0 = sorterSharedMemory[lIdx+127];
|
||||
totalSum1 = sorterSharedMemory[2*WG_SIZE+WG_SIZE*2-1];
|
||||
rank1 = sorterSharedMemory[2*WG_SIZE+lIdx+127];
|
||||
}
|
||||
|
||||
uint4 localPrefixSum128V( uint4 pData, uint lIdx, inout uint totalSum )
|
||||
{
|
||||
{ // Set data
|
||||
sorterSharedMemory[lIdx] = 0;
|
||||
sorterSharedMemory[lIdx+WG_SIZE] = prefixScanVectorEx( pData );
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
{ // Prefix sum
|
||||
int idx = 2*lIdx + (WG_SIZE+1);
|
||||
if( lIdx < 64 )
|
||||
{
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-1];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-2];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-4];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-8];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-16];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-32];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-64];
|
||||
|
||||
sorterSharedMemory[idx-1] += sorterSharedMemory[idx-2];
|
||||
}
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
totalSum = sorterSharedMemory[WG_SIZE*2-1];
|
||||
uint addValue = sorterSharedMemory[lIdx+127];
|
||||
return pData + uint4(addValue, addValue, addValue, addValue);
|
||||
}
|
||||
|
||||
void localPrefixSum128Dual( uint4 pData0, uint4 pData1, uint lIdx,
|
||||
inout uint4 dataOut0, inout uint4 dataOut1,
|
||||
inout uint totalSum0, inout uint totalSum1 )
|
||||
{
|
||||
/*
|
||||
dataOut0 = localPrefixSum128V( pData0, lIdx, totalSum0 );
|
||||
GROUP_LDS_BARRIER;
|
||||
dataOut1 = localPrefixSum128V( pData1, lIdx, totalSum1 );
|
||||
return;
|
||||
*/
|
||||
|
||||
uint4 backup0 = pData0;
|
||||
uint4 backup1 = pData1;
|
||||
|
||||
{ // Prefix sum in a vector
|
||||
pData0 = prefixScanVector( pData0 );
|
||||
pData1 = prefixScanVector( pData1 );
|
||||
}
|
||||
|
||||
{ // Set data
|
||||
sorterSharedMemory[lIdx] = 0;
|
||||
sorterSharedMemory[lIdx+WG_SIZE] = pData0.w;
|
||||
sorterSharedMemory[2*WG_SIZE+lIdx] = 0;
|
||||
sorterSharedMemory[2*WG_SIZE+lIdx+WG_SIZE] = pData1.w;
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
// if( lIdx < 128 ) // todo. assert wg size is 128
|
||||
{ // Prefix sum
|
||||
int blockIdx = lIdx/64;
|
||||
int groupIdx = lIdx%64;
|
||||
int idx = 2*groupIdx + (WG_SIZE+1) + (2*WG_SIZE)*blockIdx;
|
||||
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-1];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-2];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-4];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-8];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-16];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-32];
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-64];
|
||||
|
||||
sorterSharedMemory[idx-1] += sorterSharedMemory[idx-2];
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
totalSum0 = sorterSharedMemory[WG_SIZE*2-1];
|
||||
{
|
||||
uint addValue = sorterSharedMemory[lIdx+127];
|
||||
dataOut0 = pData0 + uint4(addValue, addValue, addValue, addValue) - backup0;
|
||||
}
|
||||
|
||||
totalSum1 = sorterSharedMemory[2*WG_SIZE+WG_SIZE*2-1];
|
||||
{
|
||||
uint addValue = sorterSharedMemory[2*WG_SIZE+lIdx+127];
|
||||
dataOut1 = pData1 + uint4(addValue, addValue, addValue, addValue) - backup1;
|
||||
}
|
||||
}
|
||||
|
||||
uint4 extractKeys(uint4 data, uint targetKey)
|
||||
{
|
||||
uint4 key;
|
||||
key.x = data.x == targetKey ? 1:0;
|
||||
key.y = data.y == targetKey ? 1:0;
|
||||
key.z = data.z == targetKey ? 1:0;
|
||||
key.w = data.w == targetKey ? 1:0;
|
||||
return key;
|
||||
}
|
||||
|
||||
uint4 extractKeysByBits(uint4 data, uint targetKey)
|
||||
{
|
||||
uint4 key;
|
||||
uint mask = 1<<targetKey;
|
||||
key.x = (data.x & mask) >> targetKey;
|
||||
key.y = (data.y & mask) >> targetKey;
|
||||
key.z = (data.z & mask) >> targetKey;
|
||||
key.w = (data.w & mask) >> targetKey;
|
||||
return key;
|
||||
}
|
||||
|
||||
uint packKeys(uint lower, uint upper)
|
||||
{
|
||||
return lower|(upper<<16);
|
||||
}
|
||||
|
||||
uint4 packKeys(uint4 lower, uint4 upper)
|
||||
{
|
||||
return uint4( lower.x|(upper.x<<16), lower.y|(upper.y<<16), lower.z|(upper.z<<16), lower.w|(upper.w<<16) );
|
||||
}
|
||||
|
||||
uint extractLower( uint data )
|
||||
{
|
||||
return data&0xffff;
|
||||
}
|
||||
|
||||
uint extractUpper( uint data )
|
||||
{
|
||||
return (data>>16)&0xffff;
|
||||
}
|
||||
|
||||
uint4 extractLower( uint4 data )
|
||||
{
|
||||
return uint4( data.x&0xffff, data.y&0xffff, data.z&0xffff, data.w&0xffff );
|
||||
}
|
||||
|
||||
uint4 extractUpper( uint4 data )
|
||||
{
|
||||
return uint4( (data.x>>16)&0xffff, (data.y>>16)&0xffff, (data.z>>16)&0xffff, (data.w>>16)&0xffff );
|
||||
}
|
||||
|
||||
[numthreads(WG_SIZE, 1, 1)]
|
||||
void SortAndScatterKernel( DEFAULT_ARGS )
|
||||
{
|
||||
u32 lIdx = GET_LOCAL_IDX;
|
||||
u32 wgIdx = GET_GROUP_IDX;
|
||||
|
||||
if( lIdx < (NUM_BUCKET) )
|
||||
{
|
||||
localHistogramToCarry[lIdx] = rHistogram[lIdx*m_nWorkGroupsToExecute + wgIdx];
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
for(uint igroup=wgIdx*m_nBlocksPerGroup; igroup<min2(m_totalBlocks,(wgIdx+1)*m_nBlocksPerGroup); igroup++)
|
||||
{
|
||||
u32 myHistogram;
|
||||
if( lIdx < (NUM_BUCKET) )
|
||||
{
|
||||
localPrefixSum[lIdx] = 0.f;
|
||||
}
|
||||
|
||||
u32 newOffset[4];
|
||||
KeyValuePair myData[4];
|
||||
{ // read data
|
||||
int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;
|
||||
uint startAddress = igroup*numLocalElements + lIdx*4;
|
||||
|
||||
myData[0] = dataToSort[startAddress+0];
|
||||
myData[1] = dataToSort[startAddress+1];
|
||||
myData[2] = dataToSort[startAddress+2];
|
||||
myData[3] = dataToSort[startAddress+3];
|
||||
|
||||
newOffset[0] = newOffset[1] = newOffset[2] = newOffset[3] = 0;
|
||||
}
|
||||
|
||||
int localOffset = 0;
|
||||
uint4 b = uint4((myData[0].key>>m_startBit) & 0xf, (myData[1].key>>m_startBit) & 0xf, (myData[2].key>>m_startBit) & 0xf, (myData[3].key>>m_startBit) & 0xf);
|
||||
for(uint targetKey=0; targetKey<(NUM_BUCKET); targetKey+=4)
|
||||
{
|
||||
uint4 key[4];
|
||||
uint keySet[2];
|
||||
{ // pack 4
|
||||
uint4 scannedKey[4];
|
||||
key[0] = scannedKey[0] = extractKeys( b, targetKey+0 );
|
||||
key[1] = scannedKey[1] = extractKeys( b, targetKey+1 );
|
||||
key[2] = scannedKey[2] = extractKeys( b, targetKey+2 );
|
||||
key[3] = scannedKey[3] = extractKeys( b, targetKey+3 );
|
||||
{
|
||||
uint s[4];
|
||||
s[0] = prefixScanVectorEx( scannedKey[0] );
|
||||
s[1] = prefixScanVectorEx( scannedKey[1] );
|
||||
s[2] = prefixScanVectorEx( scannedKey[2] );
|
||||
s[3] = prefixScanVectorEx( scannedKey[3] );
|
||||
keySet[0] = packKeys( s[0], s[1] );
|
||||
keySet[1] = packKeys( s[2], s[3] );
|
||||
}
|
||||
}
|
||||
|
||||
uint dstAddressBase[4];
|
||||
{
|
||||
|
||||
uint totalSumPacked[2];
|
||||
uint dstAddressPacked[2];
|
||||
|
||||
localPrefixScan128Dual( keySet[0], keySet[1], lIdx, dstAddressPacked[0], dstAddressPacked[1], totalSumPacked[0], totalSumPacked[1] );
|
||||
|
||||
dstAddressBase[0] = extractLower( dstAddressPacked[0] );
|
||||
dstAddressBase[1] = extractUpper( dstAddressPacked[0] );
|
||||
dstAddressBase[2] = extractLower( dstAddressPacked[1] );
|
||||
dstAddressBase[3] = extractUpper( dstAddressPacked[1] );
|
||||
|
||||
uint4 histogram;
|
||||
histogram.x = extractLower(totalSumPacked[0]);
|
||||
histogram.y = extractUpper(totalSumPacked[0]);
|
||||
histogram.z = extractLower(totalSumPacked[1]);
|
||||
histogram.w = extractUpper(totalSumPacked[1]);
|
||||
|
||||
if( lIdx == targetKey + 0 ) myHistogram = histogram.x;
|
||||
else if( lIdx == targetKey + 1 ) myHistogram = histogram.y;
|
||||
else if( lIdx == targetKey + 2 ) myHistogram = histogram.z;
|
||||
else if( lIdx == targetKey + 3 ) myHistogram = histogram.w;
|
||||
|
||||
uint histogramSum = prefixScanVectorEx( histogram );
|
||||
|
||||
if( lIdx == targetKey + 0 ) localPrefixSum[targetKey+0] = localOffset+histogram.x;
|
||||
else if( lIdx == targetKey + 1 ) localPrefixSum[targetKey+1] = localOffset+histogram.y;
|
||||
else if( lIdx == targetKey + 2 ) localPrefixSum[targetKey+2] = localOffset+histogram.z;
|
||||
else if( lIdx == targetKey + 3 ) localPrefixSum[targetKey+3] = localOffset+histogram.w;
|
||||
|
||||
localOffset += histogramSum;
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
|
||||
for(int ie=0; ie<4; ie++)
|
||||
{
|
||||
uint4 scannedKey = key[ie];
|
||||
prefixScanVectorEx( scannedKey );
|
||||
|
||||
uint offset = localPrefixSum[targetKey + ie] + dstAddressBase[ie];
|
||||
uint4 dstAddress = uint4( offset, offset, offset, offset ) + scannedKey;
|
||||
|
||||
newOffset[0] += dstAddress.x*key[ie].x;
|
||||
newOffset[1] += dstAddress.y*key[ie].y;
|
||||
newOffset[2] += dstAddress.z*key[ie].z;
|
||||
newOffset[3] += dstAddress.w*key[ie].w;
|
||||
}
|
||||
}
|
||||
|
||||
{ // local scatter
|
||||
SET_LOCAL_SORT_DATA(newOffset[0], myData[0]);
|
||||
SET_LOCAL_SORT_DATA(newOffset[1], myData[1]);
|
||||
SET_LOCAL_SORT_DATA(newOffset[2], myData[2]);
|
||||
SET_LOCAL_SORT_DATA(newOffset[3], myData[3]);
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
{ // write data
|
||||
for(int i=0; i<ELEMENTS_PER_WORK_ITEM; i++)
|
||||
{
|
||||
int dataIdx = 4*lIdx+i;
|
||||
KeyValuePair localData; GET_LOCAL_SORT_DATA( dataIdx, localData );
|
||||
int binIdx = (localData.key >> m_startBit) & 0xf;
|
||||
int groupOffset = localHistogramToCarry[binIdx];
|
||||
int myIdx = dataIdx - localPrefixSum[binIdx];
|
||||
|
||||
dataToSortOut[ groupOffset + myIdx ] = localData;
|
||||
}
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
if( lIdx < NUM_BUCKET )
|
||||
{
|
||||
localHistogramToCarry[lIdx] += myHistogram;
|
||||
}
|
||||
GROUP_LDS_BARRIER;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
[numthreads(WG_SIZE, 1, 1)]
|
||||
void SortAndScatterKernel1( DEFAULT_ARGS )
|
||||
{
|
||||
u32 lIdx = GET_LOCAL_IDX;
|
||||
u32 wgIdx = GET_GROUP_IDX;
|
||||
|
||||
if( lIdx < (NUM_BUCKET) )
|
||||
{
|
||||
localHistogramToCarry[lIdx] = rHistogram[lIdx*m_nWorkGroupsToExecute + wgIdx.x];
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
for(uint igroup=wgIdx.x*m_nBlocksPerGroup; igroup<min2(m_totalBlocks,(wgIdx.x+1)*m_nBlocksPerGroup); igroup++)
|
||||
{
|
||||
u32 myHistogram;
|
||||
|
||||
KeyValuePair myData[4];
|
||||
uint startAddrBlock;
|
||||
{ // read data
|
||||
int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;
|
||||
startAddrBlock = lIdx*4;
|
||||
uint startAddress = igroup*numLocalElements + startAddrBlock;
|
||||
|
||||
myData[0] = dataToSort[startAddress+0];
|
||||
myData[1] = dataToSort[startAddress+1];
|
||||
myData[2] = dataToSort[startAddress+2];
|
||||
myData[3] = dataToSort[startAddress+3];
|
||||
}
|
||||
|
||||
// local sort
|
||||
for(int ib=m_startBit; ib<m_startBit+BITS_PER_PASS; ib++)
|
||||
{
|
||||
uint4 keys = uint4(~(myData[0].key>>ib) & 0x1, ~(myData[1].key>>ib) & 0x1, ~(myData[2].key>>ib) & 0x1, ~(myData[3].key>>ib) & 0x1);
|
||||
uint total;
|
||||
uint4 rankOfP = localPrefixSum128V( keys, lIdx, total );
|
||||
uint4 rankOfN = uint4(startAddrBlock, startAddrBlock+1, startAddrBlock+2, startAddrBlock+3) - rankOfP + uint4( total, total, total, total );
|
||||
|
||||
uint4 myAddr = (keys==uint4(1,1,1,1))? rankOfP: rankOfN;
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
SET_LOCAL_SORT_DATA( myAddr.x, myData[0] );
|
||||
SET_LOCAL_SORT_DATA( myAddr.y, myData[1] );
|
||||
SET_LOCAL_SORT_DATA( myAddr.z, myData[2] );
|
||||
SET_LOCAL_SORT_DATA( myAddr.w, myData[3] );
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
GET_LOCAL_SORT_DATA( startAddrBlock+0, myData[0] );
|
||||
GET_LOCAL_SORT_DATA( startAddrBlock+1, myData[1] );
|
||||
GET_LOCAL_SORT_DATA( startAddrBlock+2, myData[2] );
|
||||
GET_LOCAL_SORT_DATA( startAddrBlock+3, myData[3] );
|
||||
}
|
||||
|
||||
{// create histogram -> prefix sum
|
||||
if( lIdx < NUM_BUCKET )
|
||||
{
|
||||
localHistogram[lIdx] = 0;
|
||||
localHistogram[NUM_BUCKET+lIdx] = 0;
|
||||
}
|
||||
GROUP_LDS_BARRIER;
|
||||
uint4 keys = uint4((myData[0].key>>m_startBit) & 0xf, (myData[1].key>>m_startBit) & 0xf, (myData[2].key>>m_startBit) & 0xf, (myData[3].key>>m_startBit) & 0xf);
|
||||
|
||||
InterlockedAdd( localHistogram[NUM_BUCKET+keys.x], 1 );
|
||||
InterlockedAdd( localHistogram[NUM_BUCKET+keys.y], 1 );
|
||||
InterlockedAdd( localHistogram[NUM_BUCKET+keys.z], 1 );
|
||||
InterlockedAdd( localHistogram[NUM_BUCKET+keys.w], 1 );
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
uint hIdx = NUM_BUCKET+lIdx;
|
||||
if( lIdx < NUM_BUCKET )
|
||||
{
|
||||
myHistogram = localHistogram[hIdx];
|
||||
}
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
if( lIdx < NUM_BUCKET )
|
||||
{
|
||||
localHistogram[hIdx] = localHistogram[hIdx-1];
|
||||
|
||||
localHistogram[hIdx] += localHistogram[hIdx-1];
|
||||
localHistogram[hIdx] += localHistogram[hIdx-2];
|
||||
localHistogram[hIdx] += localHistogram[hIdx-4];
|
||||
localHistogram[hIdx] += localHistogram[hIdx-8];
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
}
|
||||
/*
|
||||
{// write back
|
||||
int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;
|
||||
startAddrBlock = lIdx*4;
|
||||
uint startAddress = igroup*numLocalElements + startAddrBlock;
|
||||
|
||||
for(int ie=0; ie<ELEMENTS_PER_WORK_ITEM; ie++)
|
||||
{
|
||||
dataToSortOut[ startAddress+ie ] = myData[ie];
|
||||
}
|
||||
}
|
||||
*/
|
||||
{
|
||||
for(int ie=0; ie<ELEMENTS_PER_WORK_ITEM; ie++)
|
||||
{
|
||||
int dataIdx = startAddrBlock+ie;
|
||||
int binIdx = (myData[ie].key>>m_startBit)&0xf;
|
||||
int groupOffset = localHistogramToCarry[binIdx];
|
||||
int myIdx = dataIdx - localHistogram[NUM_BUCKET+binIdx];
|
||||
dataToSortOut[ groupOffset + myIdx ] = myData[ie];
|
||||
}
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
if( lIdx < NUM_BUCKET )
|
||||
{
|
||||
localHistogramToCarry[lIdx] += myHistogram;
|
||||
}
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
[numthreads(WG_SIZE, 1, 1)]
|
||||
void SortAndScatterKernel1( uint3 gIdx : SV_GroupID, uint3 lIdx : SV_GroupThreadID )
|
||||
{
|
||||
if( lIdx.x < (NUM_BUCKET) )
|
||||
{
|
||||
localHistogramToCarry[lIdx.x] = rHistogram[lIdx.x*m_nWorkGroupsToExecute + gIdx.x];
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
for(uint igroup=gIdx.x*m_nBlocksPerGroup; igroup<min2(m_totalBlocks,(gIdx.x+1)*m_nBlocksPerGroup); igroup++)
|
||||
{
|
||||
u32 myHistogram;
|
||||
|
||||
KeyValuePair myData[4];
|
||||
uint startAddrBlock;
|
||||
{ // read data
|
||||
int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;
|
||||
startAddrBlock = lIdx.x*4;
|
||||
uint startAddress = igroup*numLocalElements + startAddrBlock;
|
||||
|
||||
myData[0] = dataToSort[startAddress+0];
|
||||
myData[1] = dataToSort[startAddress+1];
|
||||
myData[2] = dataToSort[startAddress+2];
|
||||
myData[3] = dataToSort[startAddress+3];
|
||||
}
|
||||
|
||||
for(int ib=m_startBit; ib<m_startBit+BITS_PER_PASS; ib++)
|
||||
{
|
||||
uint4 keys = uint4(~(myData[0].key>>ib) & 0x1, ~(myData[1].key>>ib) & 0x1, ~(myData[2].key>>ib) & 0x1, ~(myData[3].key>>ib) & 0x1);
|
||||
uint total;
|
||||
uint4 rankOfP = localPrefixSum128V( keys, lIdx.x, total );
|
||||
uint4 rankOfN = uint4(startAddrBlock, startAddrBlock+1, startAddrBlock+2, startAddrBlock+3) - rankOfP + uint4( total, total, total, total );
|
||||
|
||||
uint4 myAddr = (keys==uint4(1,1,1,1))? rankOfP: rankOfN;
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
SET_LOCAL_SORT_DATA( myAddr.x, myData[0] );
|
||||
SET_LOCAL_SORT_DATA( myAddr.y, myData[1] );
|
||||
SET_LOCAL_SORT_DATA( myAddr.z, myData[2] );
|
||||
SET_LOCAL_SORT_DATA( myAddr.w, myData[3] );
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
GET_LOCAL_SORT_DATA( startAddrBlock+0, myData[0] );
|
||||
GET_LOCAL_SORT_DATA( startAddrBlock+1, myData[1] );
|
||||
GET_LOCAL_SORT_DATA( startAddrBlock+2, myData[2] );
|
||||
GET_LOCAL_SORT_DATA( startAddrBlock+3, myData[3] );
|
||||
}
|
||||
|
||||
{// create histogram -> prefix sum
|
||||
if( lIdx.x < NUM_BUCKET )
|
||||
{
|
||||
localHistogram[lIdx.x] = 0;
|
||||
localHistogram[NUM_BUCKET+lIdx.x] = 0;
|
||||
}
|
||||
GROUP_LDS_BARRIER;
|
||||
uint4 keys = uint4((myData[0].key>>m_startBit) & 0xf, (myData[1].key>>m_startBit) & 0xf, (myData[2].key>>m_startBit) & 0xf, (myData[3].key>>m_startBit) & 0xf);
|
||||
|
||||
InterlockedAdd( localHistogram[NUM_BUCKET+keys.x], 1 );
|
||||
InterlockedAdd( localHistogram[NUM_BUCKET+keys.y], 1 );
|
||||
InterlockedAdd( localHistogram[NUM_BUCKET+keys.z], 1 );
|
||||
InterlockedAdd( localHistogram[NUM_BUCKET+keys.w], 1 );
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
uint hIdx = NUM_BUCKET+lIdx.x;
|
||||
if( lIdx.x < NUM_BUCKET )
|
||||
{
|
||||
myHistogram = localHistogram[hIdx];
|
||||
}
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
|
||||
if( lIdx.x < NUM_BUCKET )
|
||||
{
|
||||
localHistogram[hIdx] = localHistogram[hIdx-1];
|
||||
|
||||
localHistogram[hIdx] += localHistogram[hIdx-1];
|
||||
localHistogram[hIdx] += localHistogram[hIdx-2];
|
||||
localHistogram[hIdx] += localHistogram[hIdx-4];
|
||||
localHistogram[hIdx] += localHistogram[hIdx-8];
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
}
|
||||
{// write back
|
||||
for(int ie=0; ie<ELEMENTS_PER_WORK_ITEM; ie++)
|
||||
{
|
||||
int dataIdx = startAddrBlock+ie;
|
||||
int binIdx = (myData[ie].key>>m_startBit)&0xf;
|
||||
int groupOffset = localHistogramToCarry[binIdx];
|
||||
int myIdx = dataIdx - localHistogram[NUM_BUCKET+binIdx];
|
||||
|
||||
dataToSortOut[ groupOffset + myIdx ] = myData[ie];
|
||||
}
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
if( lIdx.x < NUM_BUCKET )
|
||||
{
|
||||
localHistogramToCarry[lIdx.x] += myHistogram;
|
||||
}
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
StructuredBuffer<KeyValuePair> dataToSort1 : register( t0 );
|
||||
RWStructuredBuffer<u32> wHistogram1 : register(u0);
|
||||
|
||||
#define MY_HISTOGRAM(idx) localHistogramMat[(idx)*WG_SIZE+lIdx.x]
|
||||
|
||||
[numthreads(WG_SIZE, 1, 1)]
|
||||
void StreamCountKernel( DEFAULT_ARGS )
|
||||
{
|
||||
u32 lIdx = GET_LOCAL_IDX;
|
||||
u32 wgIdx = GET_GROUP_IDX;
|
||||
|
||||
int myHistogram[NUM_BUCKET];
|
||||
|
||||
for(int i=0; i<NUM_BUCKET; i++)
|
||||
{
|
||||
MY_HISTOGRAM(i) = 0;
|
||||
}
|
||||
|
||||
for(uint igroup=wgIdx.x*m_nBlocksPerGroup; igroup<min2(m_totalBlocks,(wgIdx.x+1)*m_nBlocksPerGroup); igroup++)
|
||||
{
|
||||
uint localKeys[4];
|
||||
{ // read data
|
||||
int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;
|
||||
|
||||
uint4 localAddress = uint4(lIdx, lIdx, lIdx, lIdx)*4+uint4(0,1,2,3);
|
||||
uint4 globalAddress = uint4(igroup,igroup,igroup,igroup)*numLocalElements + localAddress;
|
||||
|
||||
KeyValuePair localData0 = dataToSort1[globalAddress.x];
|
||||
KeyValuePair localData1 = dataToSort1[globalAddress.y];
|
||||
KeyValuePair localData2 = dataToSort1[globalAddress.z];
|
||||
KeyValuePair localData3 = dataToSort1[globalAddress.w];
|
||||
|
||||
localKeys[0] = (localData0.key >> m_startBit) & 0xf;
|
||||
localKeys[1] = (localData1.key >> m_startBit) & 0xf;
|
||||
localKeys[2] = (localData2.key >> m_startBit) & 0xf;
|
||||
localKeys[3] = (localData3.key >> m_startBit) & 0xf;
|
||||
}
|
||||
|
||||
MY_HISTOGRAM( localKeys[0] )++;
|
||||
MY_HISTOGRAM( localKeys[1] )++;
|
||||
MY_HISTOGRAM( localKeys[2] )++;
|
||||
MY_HISTOGRAM( localKeys[3] )++;
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
{ // reduce to 1
|
||||
if( lIdx < 64 )//WG_SIZE/2 )
|
||||
{
|
||||
for(int i=0; i<NUM_BUCKET/2; i++)
|
||||
{
|
||||
int idx = lIdx;
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+64];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+32];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+16];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+8];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+4];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+2];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+1];
|
||||
}
|
||||
}
|
||||
else if( lIdx < 128 )
|
||||
{
|
||||
for(int i=NUM_BUCKET/2; i<NUM_BUCKET; i++)
|
||||
{
|
||||
int idx = lIdx-64;
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+64];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+32];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+16];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+8];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+4];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+2];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
{ // write data
|
||||
if( lIdx < NUM_BUCKET )
|
||||
{
|
||||
wHistogram1[ lIdx*m_nWorkGroupsToExecute + wgIdx.x ] = localHistogramMat[ lIdx*WG_SIZE+0 ];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
[numthreads(WG_SIZE, 1, 1)]
|
||||
void StreamCountKernel( uint3 gIdx : SV_GroupID, uint3 lIdx : SV_GroupThreadID )
|
||||
{
|
||||
int myHistogram[NUM_BUCKET];
|
||||
|
||||
for(int i=0; i<NUM_BUCKET; i++)
|
||||
{
|
||||
myHistogram[i] = 0;
|
||||
}
|
||||
|
||||
for(uint igroup=gIdx.x*m_nBlocksPerGroup; igroup<min2(m_totalBlocks,(gIdx.x+1)*m_nBlocksPerGroup); igroup++)
|
||||
{
|
||||
uint localKeys[4];
|
||||
{ // read data
|
||||
int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;
|
||||
|
||||
uint4 localAddress = uint4(lIdx.x, lIdx.x, lIdx.x, lIdx.x)*4+uint4(0,1,2,3);
|
||||
uint4 globalAddress = uint4(igroup,igroup,igroup,igroup)*numLocalElements + localAddress;
|
||||
|
||||
KeyValuePair localData0 = dataToSort1[globalAddress.x];
|
||||
KeyValuePair localData1 = dataToSort1[globalAddress.y];
|
||||
KeyValuePair localData2 = dataToSort1[globalAddress.z];
|
||||
KeyValuePair localData3 = dataToSort1[globalAddress.w];
|
||||
|
||||
localKeys[0] = (localData0.key >> m_startBit) & 0xf;
|
||||
localKeys[1] = (localData1.key >> m_startBit) & 0xf;
|
||||
localKeys[2] = (localData2.key >> m_startBit) & 0xf;
|
||||
localKeys[3] = (localData3.key >> m_startBit) & 0xf;
|
||||
}
|
||||
|
||||
myHistogram[ localKeys[0] ]++;
|
||||
myHistogram[ localKeys[1] ]++;
|
||||
myHistogram[ localKeys[2] ]++;
|
||||
myHistogram[ localKeys[3] ]++;
|
||||
}
|
||||
|
||||
{ // move to shared
|
||||
for(int i=0; i<NUM_BUCKET; i++)
|
||||
{
|
||||
localHistogramMat[i*WG_SIZE+lIdx.x] = myHistogram[i];
|
||||
}
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
{ // reduce to 1
|
||||
if( lIdx.x < 64 )//WG_SIZE/2 )
|
||||
{
|
||||
for(int i=0; i<NUM_BUCKET/2; i++)
|
||||
{
|
||||
int idx = lIdx.x;
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+64];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+32];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+16];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+8];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+4];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+2];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+1];
|
||||
}
|
||||
}
|
||||
else if( lIdx.x < 128 )
|
||||
{
|
||||
for(int i=NUM_BUCKET/2; i<NUM_BUCKET; i++)
|
||||
{
|
||||
int idx = lIdx.x-64;
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+64];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+32];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+16];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+8];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+4];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+2];
|
||||
localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
{ // write data
|
||||
if( lIdx.x < NUM_BUCKET )
|
||||
{
|
||||
wHistogram1[ lIdx.x*m_nWorkGroupsToExecute + gIdx.x ] = localHistogramMat[ lIdx.x*WG_SIZE+0 ];
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
/*
|
||||
// for MAX_WG_SIZE 20
|
||||
[numthreads(WG_SIZE, 1, 1)]
|
||||
void PrefixScanKernel( uint3 gIdx : SV_GroupID, uint3 lIdx : SV_GroupThreadID )
|
||||
{
|
||||
uint4 myData = uint4(0,0,0,0);
|
||||
if( 4*lIdx.x+0 < NUM_BUCKET*m_nWorkGroupsToExecute )
|
||||
myData.x = wHistogram1[4*lIdx.x+0];
|
||||
if( 4*lIdx.x+1 < NUM_BUCKET*m_nWorkGroupsToExecute )
|
||||
myData.y = wHistogram1[4*lIdx.x+1];
|
||||
if( 4*lIdx.x+2 < NUM_BUCKET*m_nWorkGroupsToExecute )
|
||||
myData.z = wHistogram1[4*lIdx.x+2];
|
||||
if( 4*lIdx.x+3 < NUM_BUCKET*m_nWorkGroupsToExecute )
|
||||
myData.w = wHistogram1[4*lIdx.x+3];
|
||||
|
||||
uint totalSum;
|
||||
|
||||
uint4 scanned = localPrefixSum128V( myData, lIdx.x, totalSum );
|
||||
|
||||
wHistogram1[4*lIdx.x+0] = scanned.x;
|
||||
wHistogram1[4*lIdx.x+1] = scanned.y;
|
||||
wHistogram1[4*lIdx.x+2] = scanned.z;
|
||||
wHistogram1[4*lIdx.x+3] = scanned.w;
|
||||
}
|
||||
*/
|
||||
|
||||
// for MAX_WG_SIZE 80
|
||||
// can hold up to WG_SIZE*12 (128*12 > 80*16 )
|
||||
[numthreads(WG_SIZE, 1, 1)]
|
||||
void PrefixScanKernel( DEFAULT_ARGS )
|
||||
{
|
||||
u32 lIdx = GET_LOCAL_IDX;
|
||||
u32 wgIdx = GET_GROUP_IDX;
|
||||
|
||||
uint data[12] = {0,0,0,0,0,0,0,0,0,0,0,0};
|
||||
for(int i=0; i<12; i++)
|
||||
{
|
||||
if( int(12*lIdx+i) < NUM_BUCKET*m_nWorkGroupsToExecute )
|
||||
data[i] = wHistogram1[12*lIdx+i];
|
||||
}
|
||||
|
||||
uint4 myData = uint4(0,0,0,0);
|
||||
myData.x = data[0] + data[1];
|
||||
myData.y = data[2] + data[3];
|
||||
myData.z = data[4] + data[5];
|
||||
myData.w = data[6] + data[7];
|
||||
|
||||
|
||||
uint totalSum;
|
||||
uint4 scanned = localPrefixSum128V( myData, lIdx, totalSum );
|
||||
|
||||
data[11] = scanned.w + data[9] + data[10];
|
||||
data[10] = scanned.w + data[9];
|
||||
data[9] = scanned.w;
|
||||
data[8] = scanned.z + data[6] + data[7];
|
||||
data[7] = scanned.z + data[6];
|
||||
data[6] = scanned.z;
|
||||
data[5] = scanned.y + data[3] + data[4];
|
||||
data[4] = scanned.y + data[3];
|
||||
data[3] = scanned.y;
|
||||
data[2] = scanned.x + data[0] + data[1];
|
||||
data[1] = scanned.x + data[0];
|
||||
data[0] = scanned.x;
|
||||
|
||||
for(int i=0; i<12; i++)
|
||||
{
|
||||
wHistogram1[12*lIdx+i] = data[i];
|
||||
}
|
||||
}
|
||||
/*
|
||||
[numthreads(WG_SIZE, 1, 1)]
|
||||
void PrefixScanKernel( DEFAULT_ARGS )
|
||||
{
|
||||
u32 lIdx = GET_LOCAL_IDX;
|
||||
u32 wgIdx = GET_GROUP_IDX;
|
||||
|
||||
uint data[8] = {0,0,0,0,0,0,0,0};
|
||||
for(int i=0; i<8; i++)
|
||||
{
|
||||
if( int(8*lIdx+i) < NUM_BUCKET*m_nWorkGroupsToExecute )
|
||||
data[i] = wHistogram1[8*lIdx+i];
|
||||
}
|
||||
|
||||
uint4 myData = uint4(0,0,0,0);
|
||||
myData.x = data[0] + data[1];
|
||||
myData.y = data[2] + data[3];
|
||||
myData.z = data[4] + data[5];
|
||||
myData.w = data[6] + data[7];
|
||||
|
||||
|
||||
uint totalSum;
|
||||
uint4 scanned = localPrefixSum128V( myData, lIdx, totalSum );
|
||||
|
||||
data[7] = scanned.w + data[6];
|
||||
data[6] = scanned.w;// + data[5];
|
||||
data[5] = scanned.z + data[4];
|
||||
data[4] = scanned.z;// + data[3];
|
||||
data[3] = scanned.y + data[2];
|
||||
data[2] = scanned.y;// + data[1];
|
||||
data[1] = scanned.x + data[0];
|
||||
data[0] = scanned.x;
|
||||
|
||||
for(int i=0; i<8; i++)
|
||||
{
|
||||
wHistogram1[8*lIdx+i] = data[i];
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
[numthreads(WG_SIZE, 1, 1)]
|
||||
void CopyKernel( DEFAULT_ARGS )
|
||||
{
|
||||
u32 lIdx = GET_LOCAL_IDX;
|
||||
u32 wgIdx = GET_GROUP_IDX;
|
||||
|
||||
for(uint igroup=wgIdx.x*m_nBlocksPerGroup; igroup<min2(m_totalBlocks,(wgIdx.x+1)*m_nBlocksPerGroup); igroup++)
|
||||
{
|
||||
KeyValuePair myData[4];
|
||||
uint startAddrBlock;
|
||||
{ // read data
|
||||
int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;
|
||||
startAddrBlock = lIdx*4;
|
||||
uint startAddress = igroup*numLocalElements + startAddrBlock;
|
||||
|
||||
myData[0] = dataToSort[startAddress+0];
|
||||
myData[1] = dataToSort[startAddress+1];
|
||||
myData[2] = dataToSort[startAddress+2];
|
||||
myData[3] = dataToSort[startAddress+3];
|
||||
}
|
||||
|
||||
{
|
||||
int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;
|
||||
uint startAddress = igroup*numLocalElements + startAddrBlock;
|
||||
|
||||
dataToSortOut[startAddress+0] = myData[0];
|
||||
dataToSortOut[startAddress+1] = myData[1];
|
||||
dataToSortOut[startAddress+2] = myData[2];
|
||||
dataToSortOut[startAddress+3] = myData[3];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,987 @@
|
||||
static const char* radixSortAdvancedKernelsDX11= \
|
||||
"/*\n"
|
||||
" 2011 Takahiro Harada\n"
|
||||
"*/\n"
|
||||
"\n"
|
||||
"typedef uint u32;\n"
|
||||
"\n"
|
||||
"#define GET_GROUP_IDX groupIdx.x\n"
|
||||
"#define GET_LOCAL_IDX localIdx.x\n"
|
||||
"#define GET_GLOBAL_IDX globalIdx.x\n"
|
||||
"#define GROUP_LDS_BARRIER GroupMemoryBarrierWithGroupSync()\n"
|
||||
"#define DEFAULT_ARGS uint3 globalIdx : SV_DispatchThreadID, uint3 localIdx : SV_GroupThreadID, uint3 groupIdx : SV_GroupID\n"
|
||||
"#define AtomInc(x) InterlockedAdd(x, 1)\n"
|
||||
"#define AtomInc1(x, out) InterlockedAdd(x, 1, out)\n"
|
||||
"\n"
|
||||
"#define min2 min\n"
|
||||
"#define max2 max\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"cbuffer CB0 : register( b0 )\n"
|
||||
"{\n"
|
||||
" int m_startBit;\n"
|
||||
" int m_totalBlocks;\n"
|
||||
" int m_nWorkGroupsToExecute;\n"
|
||||
" int m_nBlocksPerGroup;\n"
|
||||
"\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"typedef struct {\n"
|
||||
" unsigned int key;\n"
|
||||
" unsigned int value;\n"
|
||||
"} KeyValuePair;\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"StructuredBuffer<u32> rHistogram : register(t0);\n"
|
||||
"\n"
|
||||
"RWStructuredBuffer<KeyValuePair> dataToSort : register( u0 );\n"
|
||||
"RWStructuredBuffer<KeyValuePair> dataToSortOut : register( u1 );\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"#define WG_SIZE 128\n"
|
||||
"#define ELEMENTS_PER_WORK_ITEM 4\n"
|
||||
"#define BITS_PER_PASS 4\n"
|
||||
"#define NUM_BUCKET (1<<BITS_PER_PASS)\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"groupshared u32 sorterSharedMemory[max(WG_SIZE*2*2, WG_SIZE*ELEMENTS_PER_WORK_ITEM*2)];\n"
|
||||
"groupshared u32 localHistogramToCarry[NUM_BUCKET];\n"
|
||||
"groupshared u32 localHistogram[NUM_BUCKET*2];\n"
|
||||
"groupshared u32 localHistogramMat[NUM_BUCKET*WG_SIZE];\n"
|
||||
"groupshared u32 localPrefixSum[NUM_BUCKET];\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"#define SET_LOCAL_SORT_DATA(idx, sortDataIn) sorterSharedMemory[2*(idx)+0] = sortDataIn.key; sorterSharedMemory[2*(idx)+1] = sortDataIn.value; \n"
|
||||
"#define GET_LOCAL_SORT_DATA(idx, sortDataOut) sortDataOut.key = sorterSharedMemory[2*(idx)+0]; sortDataOut.value = sorterSharedMemory[2*(idx)+1];\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"uint4 prefixScanVector( uint4 data )\n"
|
||||
"{\n"
|
||||
" data.y += data.x;\n"
|
||||
" data.w += data.z;\n"
|
||||
" data.z += data.y;\n"
|
||||
" data.w += data.y;\n"
|
||||
" return data;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"uint prefixScanVectorEx( inout uint4 data )\n"
|
||||
"{\n"
|
||||
" uint4 backup = data;\n"
|
||||
" data.y += data.x;\n"
|
||||
" data.w += data.z;\n"
|
||||
" data.z += data.y;\n"
|
||||
" data.w += data.y;\n"
|
||||
" uint sum = data.w;\n"
|
||||
" data -= backup;\n"
|
||||
" return sum;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"uint localPrefixScan128( uint pData, uint lIdx, inout uint totalSum )\n"
|
||||
"{\n"
|
||||
" { // Set data\n"
|
||||
" sorterSharedMemory[lIdx] = 0;\n"
|
||||
" sorterSharedMemory[lIdx+WG_SIZE] = pData;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" { // Prefix sum\n"
|
||||
" int idx = 2*lIdx + (WG_SIZE+1);\n"
|
||||
" if( lIdx < 64 )\n"
|
||||
" {\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-1];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-2]; \n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-4];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-8]; \n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-16];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-32]; \n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-64];\n"
|
||||
" }\n"
|
||||
" if( lIdx < 64 ) sorterSharedMemory[idx-1] += sorterSharedMemory[idx-2];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" totalSum = sorterSharedMemory[WG_SIZE*2-1];\n"
|
||||
" return sorterSharedMemory[lIdx+127];\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"void localPrefixScan128Dual( uint pData0, uint pData1, uint lIdx, \n"
|
||||
" inout uint rank0, inout uint rank1,\n"
|
||||
" inout uint totalSum0, inout uint totalSum1 )\n"
|
||||
"{\n"
|
||||
" { // Set data\n"
|
||||
" sorterSharedMemory[lIdx] = 0;\n"
|
||||
" sorterSharedMemory[lIdx+WG_SIZE] = pData0;\n"
|
||||
" sorterSharedMemory[2*WG_SIZE+lIdx] = 0;\n"
|
||||
" sorterSharedMemory[2*WG_SIZE+lIdx+WG_SIZE] = pData1;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
"// if( lIdx < 128 ) // todo. assert wg size is 128\n"
|
||||
" { // Prefix sum\n"
|
||||
" int blockIdx = lIdx/64;\n"
|
||||
" int groupIdx = lIdx%64;\n"
|
||||
" int idx = 2*groupIdx + (WG_SIZE+1) + (2*WG_SIZE)*blockIdx;\n"
|
||||
"\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-1];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-2];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-4];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-8];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-16];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-32]; \n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-64];\n"
|
||||
"\n"
|
||||
" sorterSharedMemory[idx-1] += sorterSharedMemory[idx-2];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" totalSum0 = sorterSharedMemory[WG_SIZE*2-1];\n"
|
||||
" rank0 = sorterSharedMemory[lIdx+127];\n"
|
||||
" totalSum1 = sorterSharedMemory[2*WG_SIZE+WG_SIZE*2-1];\n"
|
||||
" rank1 = sorterSharedMemory[2*WG_SIZE+lIdx+127];\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"uint4 localPrefixSum128V( uint4 pData, uint lIdx, inout uint totalSum )\n"
|
||||
"{\n"
|
||||
" { // Set data\n"
|
||||
" sorterSharedMemory[lIdx] = 0;\n"
|
||||
" sorterSharedMemory[lIdx+WG_SIZE] = prefixScanVectorEx( pData );\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" { // Prefix sum\n"
|
||||
" int idx = 2*lIdx + (WG_SIZE+1);\n"
|
||||
" if( lIdx < 64 )\n"
|
||||
" {\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-1];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-2]; \n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-4];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-8]; \n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-16];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-32]; \n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-64];\n"
|
||||
"\n"
|
||||
" sorterSharedMemory[idx-1] += sorterSharedMemory[idx-2];\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" totalSum = sorterSharedMemory[WG_SIZE*2-1];\n"
|
||||
" uint addValue = sorterSharedMemory[lIdx+127];\n"
|
||||
" return pData + uint4(addValue, addValue, addValue, addValue);\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"void localPrefixSum128Dual( uint4 pData0, uint4 pData1, uint lIdx, \n"
|
||||
" inout uint4 dataOut0, inout uint4 dataOut1, \n"
|
||||
" inout uint totalSum0, inout uint totalSum1 )\n"
|
||||
"{\n"
|
||||
"/*\n"
|
||||
" dataOut0 = localPrefixSum128V( pData0, lIdx, totalSum0 );\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" dataOut1 = localPrefixSum128V( pData1, lIdx, totalSum1 );\n"
|
||||
" return;\n"
|
||||
"*/\n"
|
||||
"\n"
|
||||
" uint4 backup0 = pData0;\n"
|
||||
" uint4 backup1 = pData1;\n"
|
||||
"\n"
|
||||
" { // Prefix sum in a vector\n"
|
||||
" pData0 = prefixScanVector( pData0 );\n"
|
||||
" pData1 = prefixScanVector( pData1 );\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" { // Set data\n"
|
||||
" sorterSharedMemory[lIdx] = 0;\n"
|
||||
" sorterSharedMemory[lIdx+WG_SIZE] = pData0.w;\n"
|
||||
" sorterSharedMemory[2*WG_SIZE+lIdx] = 0;\n"
|
||||
" sorterSharedMemory[2*WG_SIZE+lIdx+WG_SIZE] = pData1.w;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
"// if( lIdx < 128 ) // todo. assert wg size is 128\n"
|
||||
" { // Prefix sum\n"
|
||||
" int blockIdx = lIdx/64;\n"
|
||||
" int groupIdx = lIdx%64;\n"
|
||||
" int idx = 2*groupIdx + (WG_SIZE+1) + (2*WG_SIZE)*blockIdx;\n"
|
||||
"\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-1];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-2];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-4];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-8];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-16];\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-32]; \n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-64];\n"
|
||||
"\n"
|
||||
" sorterSharedMemory[idx-1] += sorterSharedMemory[idx-2];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" totalSum0 = sorterSharedMemory[WG_SIZE*2-1];\n"
|
||||
" {\n"
|
||||
" uint addValue = sorterSharedMemory[lIdx+127];\n"
|
||||
" dataOut0 = pData0 + uint4(addValue, addValue, addValue, addValue) - backup0;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" totalSum1 = sorterSharedMemory[2*WG_SIZE+WG_SIZE*2-1];\n"
|
||||
" {\n"
|
||||
" uint addValue = sorterSharedMemory[2*WG_SIZE+lIdx+127];\n"
|
||||
" dataOut1 = pData1 + uint4(addValue, addValue, addValue, addValue) - backup1;\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"uint4 extractKeys(uint4 data, uint targetKey)\n"
|
||||
"{\n"
|
||||
" uint4 key;\n"
|
||||
" key.x = data.x == targetKey ? 1:0;\n"
|
||||
" key.y = data.y == targetKey ? 1:0;\n"
|
||||
" key.z = data.z == targetKey ? 1:0;\n"
|
||||
" key.w = data.w == targetKey ? 1:0;\n"
|
||||
" return key;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"uint4 extractKeysByBits(uint4 data, uint targetKey)\n"
|
||||
"{\n"
|
||||
" uint4 key;\n"
|
||||
" uint mask = 1<<targetKey;\n"
|
||||
" key.x = (data.x & mask) >> targetKey;\n"
|
||||
" key.y = (data.y & mask) >> targetKey;\n"
|
||||
" key.z = (data.z & mask) >> targetKey;\n"
|
||||
" key.w = (data.w & mask) >> targetKey;\n"
|
||||
" return key;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"uint packKeys(uint lower, uint upper)\n"
|
||||
"{\n"
|
||||
" return lower|(upper<<16);\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"uint4 packKeys(uint4 lower, uint4 upper)\n"
|
||||
"{\n"
|
||||
" return uint4( lower.x|(upper.x<<16), lower.y|(upper.y<<16), lower.z|(upper.z<<16), lower.w|(upper.w<<16) );\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"uint extractLower( uint data )\n"
|
||||
"{\n"
|
||||
" return data&0xffff;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"uint extractUpper( uint data )\n"
|
||||
"{\n"
|
||||
" return (data>>16)&0xffff;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"uint4 extractLower( uint4 data )\n"
|
||||
"{\n"
|
||||
" return uint4( data.x&0xffff, data.y&0xffff, data.z&0xffff, data.w&0xffff );\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"uint4 extractUpper( uint4 data )\n"
|
||||
"{\n"
|
||||
" return uint4( (data.x>>16)&0xffff, (data.y>>16)&0xffff, (data.z>>16)&0xffff, (data.w>>16)&0xffff );\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void SortAndScatterKernel( DEFAULT_ARGS ) \n"
|
||||
"{\n"
|
||||
" u32 lIdx = GET_LOCAL_IDX;\n"
|
||||
" u32 wgIdx = GET_GROUP_IDX;\n"
|
||||
"\n"
|
||||
" if( lIdx < (NUM_BUCKET) )\n"
|
||||
" {\n"
|
||||
" localHistogramToCarry[lIdx] = rHistogram[lIdx*m_nWorkGroupsToExecute + wgIdx];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" for(uint igroup=wgIdx*m_nBlocksPerGroup; igroup<min2(m_totalBlocks,(wgIdx+1)*m_nBlocksPerGroup); igroup++)\n"
|
||||
" {\n"
|
||||
" u32 myHistogram;\n"
|
||||
" if( lIdx < (NUM_BUCKET) )\n"
|
||||
" {\n"
|
||||
" localPrefixSum[lIdx] = 0.f;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" u32 newOffset[4];\n"
|
||||
" KeyValuePair myData[4];\n"
|
||||
" { // read data\n"
|
||||
" int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;\n"
|
||||
" uint startAddress = igroup*numLocalElements + lIdx*4;\n"
|
||||
"\n"
|
||||
" myData[0] = dataToSort[startAddress+0];\n"
|
||||
" myData[1] = dataToSort[startAddress+1];\n"
|
||||
" myData[2] = dataToSort[startAddress+2];\n"
|
||||
" myData[3] = dataToSort[startAddress+3];\n"
|
||||
"\n"
|
||||
" newOffset[0] = newOffset[1] = newOffset[2] = newOffset[3] = 0;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" int localOffset = 0;\n"
|
||||
" uint4 b = uint4((myData[0].key>>m_startBit) & 0xf, (myData[1].key>>m_startBit) & 0xf, (myData[2].key>>m_startBit) & 0xf, (myData[3].key>>m_startBit) & 0xf);\n"
|
||||
" for(uint targetKey=0; targetKey<(NUM_BUCKET); targetKey+=4)\n"
|
||||
" {\n"
|
||||
" uint4 key[4];\n"
|
||||
" uint keySet[2];\n"
|
||||
" { // pack 4\n"
|
||||
" uint4 scannedKey[4];\n"
|
||||
" key[0] = scannedKey[0] = extractKeys( b, targetKey+0 );\n"
|
||||
" key[1] = scannedKey[1] = extractKeys( b, targetKey+1 );\n"
|
||||
" key[2] = scannedKey[2] = extractKeys( b, targetKey+2 );\n"
|
||||
" key[3] = scannedKey[3] = extractKeys( b, targetKey+3 );\n"
|
||||
" {\n"
|
||||
" uint s[4];\n"
|
||||
" s[0] = prefixScanVectorEx( scannedKey[0] );\n"
|
||||
" s[1] = prefixScanVectorEx( scannedKey[1] );\n"
|
||||
" s[2] = prefixScanVectorEx( scannedKey[2] );\n"
|
||||
" s[3] = prefixScanVectorEx( scannedKey[3] );\n"
|
||||
" keySet[0] = packKeys( s[0], s[1] );\n"
|
||||
" keySet[1] = packKeys( s[2], s[3] );\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" uint dstAddressBase[4];\n"
|
||||
" {\n"
|
||||
"\n"
|
||||
" uint totalSumPacked[2];\n"
|
||||
" uint dstAddressPacked[2];\n"
|
||||
"\n"
|
||||
" localPrefixScan128Dual( keySet[0], keySet[1], lIdx, dstAddressPacked[0], dstAddressPacked[1], totalSumPacked[0], totalSumPacked[1] );\n"
|
||||
"\n"
|
||||
" dstAddressBase[0] = extractLower( dstAddressPacked[0] );\n"
|
||||
" dstAddressBase[1] = extractUpper( dstAddressPacked[0] );\n"
|
||||
" dstAddressBase[2] = extractLower( dstAddressPacked[1] );\n"
|
||||
" dstAddressBase[3] = extractUpper( dstAddressPacked[1] );\n"
|
||||
"\n"
|
||||
" uint4 histogram;\n"
|
||||
" histogram.x = extractLower(totalSumPacked[0]);\n"
|
||||
" histogram.y = extractUpper(totalSumPacked[0]);\n"
|
||||
" histogram.z = extractLower(totalSumPacked[1]);\n"
|
||||
" histogram.w = extractUpper(totalSumPacked[1]);\n"
|
||||
"\n"
|
||||
" if( lIdx == targetKey + 0 ) myHistogram = histogram.x;\n"
|
||||
" else if( lIdx == targetKey + 1 ) myHistogram = histogram.y;\n"
|
||||
" else if( lIdx == targetKey + 2 ) myHistogram = histogram.z;\n"
|
||||
" else if( lIdx == targetKey + 3 ) myHistogram = histogram.w;\n"
|
||||
" \n"
|
||||
" uint histogramSum = prefixScanVectorEx( histogram );\n"
|
||||
"\n"
|
||||
" if( lIdx == targetKey + 0 ) localPrefixSum[targetKey+0] = localOffset+histogram.x;\n"
|
||||
" else if( lIdx == targetKey + 1 ) localPrefixSum[targetKey+1] = localOffset+histogram.y;\n"
|
||||
" else if( lIdx == targetKey + 2 ) localPrefixSum[targetKey+2] = localOffset+histogram.z;\n"
|
||||
" else if( lIdx == targetKey + 3 ) localPrefixSum[targetKey+3] = localOffset+histogram.w;\n"
|
||||
"\n"
|
||||
" localOffset += histogramSum;\n"
|
||||
" }\n"
|
||||
" \n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
"\n"
|
||||
" for(int ie=0; ie<4; ie++)\n"
|
||||
" {\n"
|
||||
" uint4 scannedKey = key[ie];\n"
|
||||
" prefixScanVectorEx( scannedKey );\n"
|
||||
"\n"
|
||||
" uint offset = localPrefixSum[targetKey + ie] + dstAddressBase[ie];\n"
|
||||
" uint4 dstAddress = uint4( offset, offset, offset, offset ) + scannedKey;\n"
|
||||
"\n"
|
||||
" newOffset[0] += dstAddress.x*key[ie].x;\n"
|
||||
" newOffset[1] += dstAddress.y*key[ie].y;\n"
|
||||
" newOffset[2] += dstAddress.z*key[ie].z;\n"
|
||||
" newOffset[3] += dstAddress.w*key[ie].w;\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" { // local scatter\n"
|
||||
" SET_LOCAL_SORT_DATA(newOffset[0], myData[0]);\n"
|
||||
" SET_LOCAL_SORT_DATA(newOffset[1], myData[1]);\n"
|
||||
" SET_LOCAL_SORT_DATA(newOffset[2], myData[2]);\n"
|
||||
" SET_LOCAL_SORT_DATA(newOffset[3], myData[3]);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" { // write data\n"
|
||||
" for(int i=0; i<ELEMENTS_PER_WORK_ITEM; i++)\n"
|
||||
" {\n"
|
||||
" int dataIdx = 4*lIdx+i;\n"
|
||||
" KeyValuePair localData; GET_LOCAL_SORT_DATA( dataIdx, localData );\n"
|
||||
" int binIdx = (localData.key >> m_startBit) & 0xf;\n"
|
||||
" int groupOffset = localHistogramToCarry[binIdx];\n"
|
||||
" int myIdx = dataIdx - localPrefixSum[binIdx];\n"
|
||||
"\n"
|
||||
" dataToSortOut[ groupOffset + myIdx ] = localData;\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" if( lIdx < NUM_BUCKET )\n"
|
||||
" {\n"
|
||||
" localHistogramToCarry[lIdx] += myHistogram;\n"
|
||||
" }\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void SortAndScatterKernel1( DEFAULT_ARGS )\n"
|
||||
"{\n"
|
||||
" u32 lIdx = GET_LOCAL_IDX;\n"
|
||||
" u32 wgIdx = GET_GROUP_IDX;\n"
|
||||
"\n"
|
||||
" if( lIdx < (NUM_BUCKET) )\n"
|
||||
" {\n"
|
||||
" localHistogramToCarry[lIdx] = rHistogram[lIdx*m_nWorkGroupsToExecute + wgIdx.x];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" for(uint igroup=wgIdx.x*m_nBlocksPerGroup; igroup<min2(m_totalBlocks,(wgIdx.x+1)*m_nBlocksPerGroup); igroup++)\n"
|
||||
" {\n"
|
||||
" u32 myHistogram;\n"
|
||||
"\n"
|
||||
" KeyValuePair myData[4];\n"
|
||||
" uint startAddrBlock;\n"
|
||||
" { // read data\n"
|
||||
" int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;\n"
|
||||
" startAddrBlock = lIdx*4;\n"
|
||||
" uint startAddress = igroup*numLocalElements + startAddrBlock;\n"
|
||||
"\n"
|
||||
" myData[0] = dataToSort[startAddress+0];\n"
|
||||
" myData[1] = dataToSort[startAddress+1];\n"
|
||||
" myData[2] = dataToSort[startAddress+2];\n"
|
||||
" myData[3] = dataToSort[startAddress+3];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" // local sort\n"
|
||||
" for(int ib=m_startBit; ib<m_startBit+BITS_PER_PASS; ib++)\n"
|
||||
" {\n"
|
||||
" uint4 keys = uint4(~(myData[0].key>>ib) & 0x1, ~(myData[1].key>>ib) & 0x1, ~(myData[2].key>>ib) & 0x1, ~(myData[3].key>>ib) & 0x1);\n"
|
||||
" uint total;\n"
|
||||
" uint4 rankOfP = localPrefixSum128V( keys, lIdx, total );\n"
|
||||
" uint4 rankOfN = uint4(startAddrBlock, startAddrBlock+1, startAddrBlock+2, startAddrBlock+3) - rankOfP + uint4( total, total, total, total );\n"
|
||||
"\n"
|
||||
" uint4 myAddr = (keys==uint4(1,1,1,1))? rankOfP: rankOfN;\n"
|
||||
" \n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" SET_LOCAL_SORT_DATA( myAddr.x, myData[0] );\n"
|
||||
" SET_LOCAL_SORT_DATA( myAddr.y, myData[1] );\n"
|
||||
" SET_LOCAL_SORT_DATA( myAddr.z, myData[2] );\n"
|
||||
" SET_LOCAL_SORT_DATA( myAddr.w, myData[3] );\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
" GET_LOCAL_SORT_DATA( startAddrBlock+0, myData[0] );\n"
|
||||
" GET_LOCAL_SORT_DATA( startAddrBlock+1, myData[1] );\n"
|
||||
" GET_LOCAL_SORT_DATA( startAddrBlock+2, myData[2] );\n"
|
||||
" GET_LOCAL_SORT_DATA( startAddrBlock+3, myData[3] );\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" {// create histogram -> prefix sum\n"
|
||||
" if( lIdx < NUM_BUCKET )\n"
|
||||
" {\n"
|
||||
" localHistogram[lIdx] = 0;\n"
|
||||
" localHistogram[NUM_BUCKET+lIdx] = 0;\n"
|
||||
" }\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" uint4 keys = uint4((myData[0].key>>m_startBit) & 0xf, (myData[1].key>>m_startBit) & 0xf, (myData[2].key>>m_startBit) & 0xf, (myData[3].key>>m_startBit) & 0xf);\n"
|
||||
" \n"
|
||||
" InterlockedAdd( localHistogram[NUM_BUCKET+keys.x], 1 );\n"
|
||||
" InterlockedAdd( localHistogram[NUM_BUCKET+keys.y], 1 );\n"
|
||||
" InterlockedAdd( localHistogram[NUM_BUCKET+keys.z], 1 );\n"
|
||||
" InterlockedAdd( localHistogram[NUM_BUCKET+keys.w], 1 );\n"
|
||||
" \n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
" uint hIdx = NUM_BUCKET+lIdx;\n"
|
||||
" if( lIdx < NUM_BUCKET )\n"
|
||||
" {\n"
|
||||
" myHistogram = localHistogram[hIdx];\n"
|
||||
" }\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
" if( lIdx < NUM_BUCKET )\n"
|
||||
" {\n"
|
||||
" localHistogram[hIdx] = localHistogram[hIdx-1];\n"
|
||||
"\n"
|
||||
" localHistogram[hIdx] += localHistogram[hIdx-1];\n"
|
||||
" localHistogram[hIdx] += localHistogram[hIdx-2];\n"
|
||||
" localHistogram[hIdx] += localHistogram[hIdx-4];\n"
|
||||
" localHistogram[hIdx] += localHistogram[hIdx-8];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" }\n"
|
||||
"/*\n"
|
||||
" {// write back\n"
|
||||
" int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;\n"
|
||||
" startAddrBlock = lIdx*4;\n"
|
||||
" uint startAddress = igroup*numLocalElements + startAddrBlock;\n"
|
||||
"\n"
|
||||
" for(int ie=0; ie<ELEMENTS_PER_WORK_ITEM; ie++)\n"
|
||||
" {\n"
|
||||
" dataToSortOut[ startAddress+ie ] = myData[ie];\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"*/\n"
|
||||
" {\n"
|
||||
" for(int ie=0; ie<ELEMENTS_PER_WORK_ITEM; ie++)\n"
|
||||
" {\n"
|
||||
" int dataIdx = startAddrBlock+ie;\n"
|
||||
" int binIdx = (myData[ie].key>>m_startBit)&0xf;\n"
|
||||
" int groupOffset = localHistogramToCarry[binIdx];\n"
|
||||
" int myIdx = dataIdx - localHistogram[NUM_BUCKET+binIdx];\n"
|
||||
" dataToSortOut[ groupOffset + myIdx ] = myData[ie];\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" \n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" if( lIdx < NUM_BUCKET )\n"
|
||||
" {\n"
|
||||
" localHistogramToCarry[lIdx] += myHistogram;\n"
|
||||
" }\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"/*\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void SortAndScatterKernel1( uint3 gIdx : SV_GroupID, uint3 lIdx : SV_GroupThreadID )\n"
|
||||
"{\n"
|
||||
" if( lIdx.x < (NUM_BUCKET) )\n"
|
||||
" {\n"
|
||||
" localHistogramToCarry[lIdx.x] = rHistogram[lIdx.x*m_nWorkGroupsToExecute + gIdx.x];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" for(uint igroup=gIdx.x*m_nBlocksPerGroup; igroup<min2(m_totalBlocks,(gIdx.x+1)*m_nBlocksPerGroup); igroup++)\n"
|
||||
" {\n"
|
||||
" u32 myHistogram;\n"
|
||||
"\n"
|
||||
" KeyValuePair myData[4];\n"
|
||||
" uint startAddrBlock;\n"
|
||||
" { // read data\n"
|
||||
" int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;\n"
|
||||
" startAddrBlock = lIdx.x*4;\n"
|
||||
" uint startAddress = igroup*numLocalElements + startAddrBlock;\n"
|
||||
"\n"
|
||||
" myData[0] = dataToSort[startAddress+0];\n"
|
||||
" myData[1] = dataToSort[startAddress+1];\n"
|
||||
" myData[2] = dataToSort[startAddress+2];\n"
|
||||
" myData[3] = dataToSort[startAddress+3];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" for(int ib=m_startBit; ib<m_startBit+BITS_PER_PASS; ib++)\n"
|
||||
" {\n"
|
||||
" uint4 keys = uint4(~(myData[0].key>>ib) & 0x1, ~(myData[1].key>>ib) & 0x1, ~(myData[2].key>>ib) & 0x1, ~(myData[3].key>>ib) & 0x1);\n"
|
||||
" uint total;\n"
|
||||
" uint4 rankOfP = localPrefixSum128V( keys, lIdx.x, total );\n"
|
||||
" uint4 rankOfN = uint4(startAddrBlock, startAddrBlock+1, startAddrBlock+2, startAddrBlock+3) - rankOfP + uint4( total, total, total, total );\n"
|
||||
"\n"
|
||||
" uint4 myAddr = (keys==uint4(1,1,1,1))? rankOfP: rankOfN;\n"
|
||||
" \n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" SET_LOCAL_SORT_DATA( myAddr.x, myData[0] );\n"
|
||||
" SET_LOCAL_SORT_DATA( myAddr.y, myData[1] );\n"
|
||||
" SET_LOCAL_SORT_DATA( myAddr.z, myData[2] );\n"
|
||||
" SET_LOCAL_SORT_DATA( myAddr.w, myData[3] );\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
" GET_LOCAL_SORT_DATA( startAddrBlock+0, myData[0] );\n"
|
||||
" GET_LOCAL_SORT_DATA( startAddrBlock+1, myData[1] );\n"
|
||||
" GET_LOCAL_SORT_DATA( startAddrBlock+2, myData[2] );\n"
|
||||
" GET_LOCAL_SORT_DATA( startAddrBlock+3, myData[3] );\n"
|
||||
" }\n"
|
||||
" \n"
|
||||
" {// create histogram -> prefix sum\n"
|
||||
" if( lIdx.x < NUM_BUCKET )\n"
|
||||
" {\n"
|
||||
" localHistogram[lIdx.x] = 0;\n"
|
||||
" localHistogram[NUM_BUCKET+lIdx.x] = 0;\n"
|
||||
" }\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" uint4 keys = uint4((myData[0].key>>m_startBit) & 0xf, (myData[1].key>>m_startBit) & 0xf, (myData[2].key>>m_startBit) & 0xf, (myData[3].key>>m_startBit) & 0xf);\n"
|
||||
" \n"
|
||||
" InterlockedAdd( localHistogram[NUM_BUCKET+keys.x], 1 );\n"
|
||||
" InterlockedAdd( localHistogram[NUM_BUCKET+keys.y], 1 );\n"
|
||||
" InterlockedAdd( localHistogram[NUM_BUCKET+keys.z], 1 );\n"
|
||||
" InterlockedAdd( localHistogram[NUM_BUCKET+keys.w], 1 );\n"
|
||||
" \n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
" uint hIdx = NUM_BUCKET+lIdx.x;\n"
|
||||
" if( lIdx.x < NUM_BUCKET )\n"
|
||||
" {\n"
|
||||
" myHistogram = localHistogram[hIdx];\n"
|
||||
" }\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
"\n"
|
||||
" if( lIdx.x < NUM_BUCKET )\n"
|
||||
" {\n"
|
||||
" localHistogram[hIdx] = localHistogram[hIdx-1];\n"
|
||||
"\n"
|
||||
" localHistogram[hIdx] += localHistogram[hIdx-1];\n"
|
||||
" localHistogram[hIdx] += localHistogram[hIdx-2];\n"
|
||||
" localHistogram[hIdx] += localHistogram[hIdx-4];\n"
|
||||
" localHistogram[hIdx] += localHistogram[hIdx-8];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" }\n"
|
||||
" {// write back\n"
|
||||
" for(int ie=0; ie<ELEMENTS_PER_WORK_ITEM; ie++)\n"
|
||||
" {\n"
|
||||
" int dataIdx = startAddrBlock+ie;\n"
|
||||
" int binIdx = (myData[ie].key>>m_startBit)&0xf;\n"
|
||||
" int groupOffset = localHistogramToCarry[binIdx];\n"
|
||||
" int myIdx = dataIdx - localHistogram[NUM_BUCKET+binIdx];\n"
|
||||
" \n"
|
||||
" dataToSortOut[ groupOffset + myIdx ] = myData[ie];\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" \n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" if( lIdx.x < NUM_BUCKET )\n"
|
||||
" {\n"
|
||||
" localHistogramToCarry[lIdx.x] += myHistogram;\n"
|
||||
" }\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"*/\n"
|
||||
"\n"
|
||||
"StructuredBuffer<KeyValuePair> dataToSort1 : register( t0 );\n"
|
||||
"RWStructuredBuffer<u32> wHistogram1 : register(u0);\n"
|
||||
"\n"
|
||||
"#define MY_HISTOGRAM(idx) localHistogramMat[(idx)*WG_SIZE+lIdx.x]\n"
|
||||
"\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void StreamCountKernel( DEFAULT_ARGS ) \n"
|
||||
"{\n"
|
||||
" u32 lIdx = GET_LOCAL_IDX;\n"
|
||||
" u32 wgIdx = GET_GROUP_IDX;\n"
|
||||
"\n"
|
||||
" int myHistogram[NUM_BUCKET];\n"
|
||||
"\n"
|
||||
" for(int i=0; i<NUM_BUCKET; i++)\n"
|
||||
" {\n"
|
||||
" MY_HISTOGRAM(i) = 0;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" for(uint igroup=wgIdx.x*m_nBlocksPerGroup; igroup<min2(m_totalBlocks,(wgIdx.x+1)*m_nBlocksPerGroup); igroup++)\n"
|
||||
" {\n"
|
||||
" uint localKeys[4];\n"
|
||||
" { // read data\n"
|
||||
" int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;\n"
|
||||
"\n"
|
||||
" uint4 localAddress = uint4(lIdx, lIdx, lIdx, lIdx)*4+uint4(0,1,2,3);\n"
|
||||
" uint4 globalAddress = uint4(igroup,igroup,igroup,igroup)*numLocalElements + localAddress;\n"
|
||||
"\n"
|
||||
" KeyValuePair localData0 = dataToSort1[globalAddress.x];\n"
|
||||
" KeyValuePair localData1 = dataToSort1[globalAddress.y];\n"
|
||||
" KeyValuePair localData2 = dataToSort1[globalAddress.z];\n"
|
||||
" KeyValuePair localData3 = dataToSort1[globalAddress.w];\n"
|
||||
"\n"
|
||||
" localKeys[0] = (localData0.key >> m_startBit) & 0xf;\n"
|
||||
" localKeys[1] = (localData1.key >> m_startBit) & 0xf;\n"
|
||||
" localKeys[2] = (localData2.key >> m_startBit) & 0xf;\n"
|
||||
" localKeys[3] = (localData3.key >> m_startBit) & 0xf;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" MY_HISTOGRAM( localKeys[0] )++;\n"
|
||||
" MY_HISTOGRAM( localKeys[1] )++;\n"
|
||||
" MY_HISTOGRAM( localKeys[2] )++;\n"
|
||||
" MY_HISTOGRAM( localKeys[3] )++;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" { // reduce to 1\n"
|
||||
" if( lIdx < 64 )//WG_SIZE/2 )\n"
|
||||
" {\n"
|
||||
" for(int i=0; i<NUM_BUCKET/2; i++)\n"
|
||||
" {\n"
|
||||
" int idx = lIdx;\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+64];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+32];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+16];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+8];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+4];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+2];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+1];\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" else if( lIdx < 128 )\n"
|
||||
" {\n"
|
||||
" for(int i=NUM_BUCKET/2; i<NUM_BUCKET; i++)\n"
|
||||
" {\n"
|
||||
" int idx = lIdx-64;\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+64];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+32];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+16];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+8];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+4];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+2];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+1];\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" { // write data\n"
|
||||
" if( lIdx < NUM_BUCKET )\n"
|
||||
" {\n"
|
||||
" wHistogram1[ lIdx*m_nWorkGroupsToExecute + wgIdx.x ] = localHistogramMat[ lIdx*WG_SIZE+0 ];\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"/*\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void StreamCountKernel( uint3 gIdx : SV_GroupID, uint3 lIdx : SV_GroupThreadID ) \n"
|
||||
"{\n"
|
||||
" int myHistogram[NUM_BUCKET];\n"
|
||||
"\n"
|
||||
" for(int i=0; i<NUM_BUCKET; i++)\n"
|
||||
" {\n"
|
||||
" myHistogram[i] = 0;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" for(uint igroup=gIdx.x*m_nBlocksPerGroup; igroup<min2(m_totalBlocks,(gIdx.x+1)*m_nBlocksPerGroup); igroup++)\n"
|
||||
" {\n"
|
||||
" uint localKeys[4];\n"
|
||||
" { // read data\n"
|
||||
" int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;\n"
|
||||
"\n"
|
||||
" uint4 localAddress = uint4(lIdx.x, lIdx.x, lIdx.x, lIdx.x)*4+uint4(0,1,2,3);\n"
|
||||
" uint4 globalAddress = uint4(igroup,igroup,igroup,igroup)*numLocalElements + localAddress;\n"
|
||||
"\n"
|
||||
" KeyValuePair localData0 = dataToSort1[globalAddress.x];\n"
|
||||
" KeyValuePair localData1 = dataToSort1[globalAddress.y];\n"
|
||||
" KeyValuePair localData2 = dataToSort1[globalAddress.z];\n"
|
||||
" KeyValuePair localData3 = dataToSort1[globalAddress.w];\n"
|
||||
"\n"
|
||||
" localKeys[0] = (localData0.key >> m_startBit) & 0xf;\n"
|
||||
" localKeys[1] = (localData1.key >> m_startBit) & 0xf;\n"
|
||||
" localKeys[2] = (localData2.key >> m_startBit) & 0xf;\n"
|
||||
" localKeys[3] = (localData3.key >> m_startBit) & 0xf;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" myHistogram[ localKeys[0] ]++;\n"
|
||||
" myHistogram[ localKeys[1] ]++;\n"
|
||||
" myHistogram[ localKeys[2] ]++;\n"
|
||||
" myHistogram[ localKeys[3] ]++;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" { // move to shared\n"
|
||||
" for(int i=0; i<NUM_BUCKET; i++)\n"
|
||||
" {\n"
|
||||
" localHistogramMat[i*WG_SIZE+lIdx.x] = myHistogram[i];\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" { // reduce to 1\n"
|
||||
" if( lIdx.x < 64 )//WG_SIZE/2 )\n"
|
||||
" {\n"
|
||||
" for(int i=0; i<NUM_BUCKET/2; i++)\n"
|
||||
" {\n"
|
||||
" int idx = lIdx.x;\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+64];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+32];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+16];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+8];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+4];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+2];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+1];\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" else if( lIdx.x < 128 )\n"
|
||||
" {\n"
|
||||
" for(int i=NUM_BUCKET/2; i<NUM_BUCKET; i++)\n"
|
||||
" {\n"
|
||||
" int idx = lIdx.x-64;\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+64];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+32];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+16];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+8];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+4];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+2];\n"
|
||||
" localHistogramMat[i*WG_SIZE+idx] += localHistogramMat[i*WG_SIZE+idx+1];\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" { // write data\n"
|
||||
" if( lIdx.x < NUM_BUCKET )\n"
|
||||
" {\n"
|
||||
" wHistogram1[ lIdx.x*m_nWorkGroupsToExecute + gIdx.x ] = localHistogramMat[ lIdx.x*WG_SIZE+0 ];\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"*/\n"
|
||||
"\n"
|
||||
"/*\n"
|
||||
"// for MAX_WG_SIZE 20\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void PrefixScanKernel( uint3 gIdx : SV_GroupID, uint3 lIdx : SV_GroupThreadID ) \n"
|
||||
"{\n"
|
||||
" uint4 myData = uint4(0,0,0,0);\n"
|
||||
" if( 4*lIdx.x+0 < NUM_BUCKET*m_nWorkGroupsToExecute )\n"
|
||||
" myData.x = wHistogram1[4*lIdx.x+0];\n"
|
||||
" if( 4*lIdx.x+1 < NUM_BUCKET*m_nWorkGroupsToExecute )\n"
|
||||
" myData.y = wHistogram1[4*lIdx.x+1];\n"
|
||||
" if( 4*lIdx.x+2 < NUM_BUCKET*m_nWorkGroupsToExecute )\n"
|
||||
" myData.z = wHistogram1[4*lIdx.x+2];\n"
|
||||
" if( 4*lIdx.x+3 < NUM_BUCKET*m_nWorkGroupsToExecute )\n"
|
||||
" myData.w = wHistogram1[4*lIdx.x+3];\n"
|
||||
"\n"
|
||||
" uint totalSum;\n"
|
||||
"\n"
|
||||
" uint4 scanned = localPrefixSum128V( myData, lIdx.x, totalSum );\n"
|
||||
"\n"
|
||||
" wHistogram1[4*lIdx.x+0] = scanned.x;\n"
|
||||
" wHistogram1[4*lIdx.x+1] = scanned.y;\n"
|
||||
" wHistogram1[4*lIdx.x+2] = scanned.z;\n"
|
||||
" wHistogram1[4*lIdx.x+3] = scanned.w;\n"
|
||||
"}\n"
|
||||
"*/\n"
|
||||
"\n"
|
||||
"// for MAX_WG_SIZE 80\n"
|
||||
"// can hold up to WG_SIZE*12 (128*12 > 80*16 )\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void PrefixScanKernel( DEFAULT_ARGS )\n"
|
||||
"{\n"
|
||||
" u32 lIdx = GET_LOCAL_IDX;\n"
|
||||
" u32 wgIdx = GET_GROUP_IDX;\n"
|
||||
"\n"
|
||||
" uint data[12] = {0,0,0,0,0,0,0,0,0,0,0,0};\n"
|
||||
" for(int i=0; i<12; i++)\n"
|
||||
" {\n"
|
||||
" if( int(12*lIdx+i) < NUM_BUCKET*m_nWorkGroupsToExecute )\n"
|
||||
" data[i] = wHistogram1[12*lIdx+i];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" uint4 myData = uint4(0,0,0,0);\n"
|
||||
" myData.x = data[0] + data[1];\n"
|
||||
" myData.y = data[2] + data[3];\n"
|
||||
" myData.z = data[4] + data[5];\n"
|
||||
" myData.w = data[6] + data[7];\n"
|
||||
"\n"
|
||||
"\n"
|
||||
" uint totalSum;\n"
|
||||
" uint4 scanned = localPrefixSum128V( myData, lIdx, totalSum );\n"
|
||||
"\n"
|
||||
" data[11] = scanned.w + data[9] + data[10];\n"
|
||||
" data[10] = scanned.w + data[9];\n"
|
||||
" data[9] = scanned.w;\n"
|
||||
" data[8] = scanned.z + data[6] + data[7];\n"
|
||||
" data[7] = scanned.z + data[6];\n"
|
||||
" data[6] = scanned.z;\n"
|
||||
" data[5] = scanned.y + data[3] + data[4];\n"
|
||||
" data[4] = scanned.y + data[3];\n"
|
||||
" data[3] = scanned.y;\n"
|
||||
" data[2] = scanned.x + data[0] + data[1];\n"
|
||||
" data[1] = scanned.x + data[0];\n"
|
||||
" data[0] = scanned.x;\n"
|
||||
"\n"
|
||||
" for(int i=0; i<12; i++)\n"
|
||||
" {\n"
|
||||
" wHistogram1[12*lIdx+i] = data[i];\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"/*\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void PrefixScanKernel( DEFAULT_ARGS )\n"
|
||||
"{\n"
|
||||
" u32 lIdx = GET_LOCAL_IDX;\n"
|
||||
" u32 wgIdx = GET_GROUP_IDX;\n"
|
||||
"\n"
|
||||
" uint data[8] = {0,0,0,0,0,0,0,0};\n"
|
||||
" for(int i=0; i<8; i++)\n"
|
||||
" {\n"
|
||||
" if( int(8*lIdx+i) < NUM_BUCKET*m_nWorkGroupsToExecute )\n"
|
||||
" data[i] = wHistogram1[8*lIdx+i];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" uint4 myData = uint4(0,0,0,0);\n"
|
||||
" myData.x = data[0] + data[1];\n"
|
||||
" myData.y = data[2] + data[3];\n"
|
||||
" myData.z = data[4] + data[5];\n"
|
||||
" myData.w = data[6] + data[7];\n"
|
||||
"\n"
|
||||
"\n"
|
||||
" uint totalSum;\n"
|
||||
" uint4 scanned = localPrefixSum128V( myData, lIdx, totalSum );\n"
|
||||
"\n"
|
||||
" data[7] = scanned.w + data[6];\n"
|
||||
" data[6] = scanned.w;// + data[5];\n"
|
||||
" data[5] = scanned.z + data[4];\n"
|
||||
" data[4] = scanned.z;// + data[3];\n"
|
||||
" data[3] = scanned.y + data[2];\n"
|
||||
" data[2] = scanned.y;// + data[1];\n"
|
||||
" data[1] = scanned.x + data[0];\n"
|
||||
" data[0] = scanned.x;\n"
|
||||
"\n"
|
||||
" for(int i=0; i<8; i++)\n"
|
||||
" {\n"
|
||||
" wHistogram1[8*lIdx+i] = data[i];\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"*/\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void CopyKernel( DEFAULT_ARGS )\n"
|
||||
"{\n"
|
||||
" u32 lIdx = GET_LOCAL_IDX;\n"
|
||||
" u32 wgIdx = GET_GROUP_IDX;\n"
|
||||
"\n"
|
||||
" for(uint igroup=wgIdx.x*m_nBlocksPerGroup; igroup<min2(m_totalBlocks,(wgIdx.x+1)*m_nBlocksPerGroup); igroup++)\n"
|
||||
" {\n"
|
||||
" KeyValuePair myData[4];\n"
|
||||
" uint startAddrBlock;\n"
|
||||
" { // read data\n"
|
||||
" int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;\n"
|
||||
" startAddrBlock = lIdx*4;\n"
|
||||
" uint startAddress = igroup*numLocalElements + startAddrBlock;\n"
|
||||
"\n"
|
||||
" myData[0] = dataToSort[startAddress+0];\n"
|
||||
" myData[1] = dataToSort[startAddress+1];\n"
|
||||
" myData[2] = dataToSort[startAddress+2];\n"
|
||||
" myData[3] = dataToSort[startAddress+3];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" {\n"
|
||||
" int numLocalElements = WG_SIZE*ELEMENTS_PER_WORK_ITEM;\n"
|
||||
" uint startAddress = igroup*numLocalElements + startAddrBlock;\n"
|
||||
"\n"
|
||||
" dataToSortOut[startAddress+0] = myData[0];\n"
|
||||
" dataToSortOut[startAddress+1] = myData[1];\n"
|
||||
" dataToSortOut[startAddress+2] = myData[2];\n"
|
||||
" dataToSortOut[startAddress+3] = myData[3];\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
;
|
||||
@@ -0,0 +1,93 @@
|
||||
/*
|
||||
2011 Takahiro Harada
|
||||
*/
|
||||
|
||||
template<>
|
||||
class RadixSort<TYPE_HOST> : public RadixSortBase
|
||||
{
|
||||
public:
|
||||
struct Data
|
||||
{
|
||||
HostBuffer<SortData>* m_workBuffer;
|
||||
};
|
||||
|
||||
enum
|
||||
{
|
||||
BITS_PER_PASS = 8,
|
||||
NUM_TABLES = (1<<BITS_PER_PASS),
|
||||
};
|
||||
|
||||
static
|
||||
Data* allocate(const Device* deviceData, int maxSize, Option option = SORT_STANDARD)
|
||||
{
|
||||
ADLASSERT( deviceData->m_type == TYPE_HOST );
|
||||
|
||||
Data* data = new Data;
|
||||
data->m_workBuffer = new HostBuffer<SortData>( deviceData, maxSize );
|
||||
return data;
|
||||
}
|
||||
|
||||
static
|
||||
void deallocate(Data* data)
|
||||
{
|
||||
delete data->m_workBuffer;
|
||||
delete data;
|
||||
}
|
||||
|
||||
static
|
||||
void execute(Data* data, Buffer<SortData>& inout, int n, int sortBits = 32)
|
||||
{
|
||||
ADLASSERT( inout.getType() == TYPE_HOST );
|
||||
|
||||
int tables[NUM_TABLES];
|
||||
int counter[NUM_TABLES];
|
||||
|
||||
SortData* src = inout.m_ptr;
|
||||
SortData* dst = data->m_workBuffer->m_ptr;
|
||||
|
||||
int count=0;
|
||||
for(int startBit=0; startBit<sortBits; startBit+=BITS_PER_PASS)
|
||||
{
|
||||
for(int i=0; i<NUM_TABLES; i++)
|
||||
{
|
||||
tables[i] = 0;
|
||||
}
|
||||
|
||||
for(int i=0; i<n; i++)
|
||||
{
|
||||
int tableIdx = (src[i].m_key >> startBit) & (NUM_TABLES-1);
|
||||
tables[tableIdx]++;
|
||||
}
|
||||
|
||||
// prefix scan
|
||||
int sum = 0;
|
||||
for(int i=0; i<NUM_TABLES; i++)
|
||||
{
|
||||
int iData = tables[i];
|
||||
tables[i] = sum;
|
||||
sum += iData;
|
||||
counter[i] = 0;
|
||||
}
|
||||
|
||||
// distribute
|
||||
for(int i=0; i<n; i++)
|
||||
{
|
||||
int tableIdx = (src[i].m_key >> startBit) & (NUM_TABLES-1);
|
||||
|
||||
dst[tables[tableIdx] + counter[tableIdx]] = src[i];
|
||||
counter[tableIdx] ++;
|
||||
}
|
||||
|
||||
swap2( src, dst );
|
||||
count++;
|
||||
}
|
||||
|
||||
{
|
||||
if (count&1)
|
||||
//if( src != inout.m_ptr )
|
||||
{
|
||||
memcpy( dst, src, sizeof(SortData)*n );
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,134 @@
|
||||
static const char* radixSortSimpleKernelsCL = \
|
||||
"#pragma OPENCL EXTENSION cl_amd_printf : enable\n"
|
||||
"#pragma OPENCL EXTENSION cl_khr_local_int32_base_atomics : enable\n"
|
||||
"\n"
|
||||
"typedef unsigned int u32;\n"
|
||||
"#define GET_GROUP_IDX get_group_id(0)\n"
|
||||
"#define GET_LOCAL_IDX get_local_id(0)\n"
|
||||
"#define GET_GLOBAL_IDX get_global_id(0)\n"
|
||||
"#define GET_GROUP_SIZE get_local_size(0)\n"
|
||||
"#define GROUP_LDS_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n"
|
||||
"#define AtomInc(x) atom_inc(&(x))\n"
|
||||
"#define AtomInc1(x, out) out = atom_inc(&(x))\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"#define WG_SIZE 128\n"
|
||||
"#define NUM_PER_WI 4\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"typedef struct\n"
|
||||
"{\n"
|
||||
" u32 m_key;\n"
|
||||
" u32 m_value;\n"
|
||||
"}SortData;\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"typedef struct\n"
|
||||
"{\n"
|
||||
" u32 m_startBit;\n"
|
||||
" u32 m_numGroups;\n"
|
||||
" u32 m_padding[2];\n"
|
||||
"} ConstBuffer;\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"__kernel\n"
|
||||
"__attribute__((reqd_work_group_size(WG_SIZE,1,1)))\n"
|
||||
"void LocalCountKernel(__global SortData* sortData,\n"
|
||||
" __global u32* ldsHistogramOut,\n"
|
||||
" ConstBuffer cb)\n"
|
||||
"{\n"
|
||||
" __local u32 ldsHistogram[16][256];\n"
|
||||
"\n"
|
||||
" int lIdx = GET_LOCAL_IDX;\n"
|
||||
" int gIdx = GET_GLOBAL_IDX;\n"
|
||||
"\n"
|
||||
" for(int i=0; i<16; i++)\n"
|
||||
" {\n"
|
||||
" ldsHistogram[i][lIdx] = 0.f;\n"
|
||||
" ldsHistogram[i][lIdx+128] = 0.f;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" SortData datas[NUM_PER_WI];\n"
|
||||
" datas[0] = sortData[gIdx*NUM_PER_WI+0];\n"
|
||||
" datas[1] = sortData[gIdx*NUM_PER_WI+1];\n"
|
||||
" datas[2] = sortData[gIdx*NUM_PER_WI+2];\n"
|
||||
" datas[3] = sortData[gIdx*NUM_PER_WI+3];\n"
|
||||
"\n"
|
||||
" datas[0].m_key = (datas[0].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
" datas[1].m_key = (datas[1].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
" datas[2].m_key = (datas[2].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
" datas[3].m_key = (datas[3].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
"\n"
|
||||
" int tableIdx = lIdx%16;\n"
|
||||
"\n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[0].m_key]);\n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[1].m_key]);\n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[2].m_key]);\n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[3].m_key]);\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" u32 sum0, sum1;\n"
|
||||
" sum0 = sum1 = 0;\n"
|
||||
" for(int i=0; i<16; i++)\n"
|
||||
" {\n"
|
||||
" sum0 += ldsHistogram[i][lIdx];\n"
|
||||
" sum1 += ldsHistogram[i][lIdx+128];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" ldsHistogramOut[lIdx*cb.m_numGroups+GET_GROUP_IDX] = sum0;\n"
|
||||
" ldsHistogramOut[(lIdx+128)*cb.m_numGroups+GET_GROUP_IDX] = sum1;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"__kernel\n"
|
||||
"__attribute__((reqd_work_group_size(WG_SIZE,1,1)))\n"
|
||||
"void ScatterKernel(__global SortData* sortData,\n"
|
||||
" __global SortData* sortDataOut,\n"
|
||||
" __global u32* scannedHistogram,\n"
|
||||
" ConstBuffer cb)\n"
|
||||
"{\n"
|
||||
" __local u32 ldsCurrentLocation[256];\n"
|
||||
"\n"
|
||||
" int lIdx = GET_LOCAL_IDX;\n"
|
||||
" int gIdx = GET_GLOBAL_IDX;\n"
|
||||
"\n"
|
||||
" {\n"
|
||||
" ldsCurrentLocation[lIdx] = scannedHistogram[lIdx*cb.m_numGroups+GET_GROUP_IDX];\n"
|
||||
" ldsCurrentLocation[lIdx+128] = scannedHistogram[(lIdx+128)*cb.m_numGroups+GET_GROUP_IDX];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" SortData datas[NUM_PER_WI];\n"
|
||||
" int keys[NUM_PER_WI];\n"
|
||||
" datas[0] = sortData[gIdx*NUM_PER_WI+0];\n"
|
||||
" datas[1] = sortData[gIdx*NUM_PER_WI+1];\n"
|
||||
" datas[2] = sortData[gIdx*NUM_PER_WI+2];\n"
|
||||
" datas[3] = sortData[gIdx*NUM_PER_WI+3];\n"
|
||||
"\n"
|
||||
" keys[0] = (datas[0].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
" keys[1] = (datas[1].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
" keys[2] = (datas[2].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
" keys[3] = (datas[3].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
"\n"
|
||||
" int dst[NUM_PER_WI];\n"
|
||||
" for(int i=0; i<WG_SIZE; i++)\n"
|
||||
" {\n"
|
||||
" if( i==lIdx )\n"
|
||||
" {\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[0]], dst[0]);\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[1]], dst[1]);\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[2]], dst[2]);\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[3]], dst[3]);\n"
|
||||
" }\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" }\n"
|
||||
" sortDataOut[dst[0]] = datas[0];\n"
|
||||
" sortDataOut[dst[1]] = datas[1];\n"
|
||||
" sortDataOut[dst[2]] = datas[2];\n"
|
||||
" sortDataOut[dst[3]] = datas[3];\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"";
|
||||
@@ -0,0 +1,131 @@
|
||||
static const char* radixSortSimpleKernelsDX11 = \
|
||||
"typedef uint u32;\n"
|
||||
"\n"
|
||||
"#define GET_GROUP_IDX groupIdx.x\n"
|
||||
"#define GET_LOCAL_IDX localIdx.x\n"
|
||||
"#define GET_GLOBAL_IDX globalIdx.x\n"
|
||||
"#define GROUP_LDS_BARRIER GroupMemoryBarrierWithGroupSync()\n"
|
||||
"#define DEFAULT_ARGS uint3 globalIdx : SV_DispatchThreadID, uint3 localIdx : SV_GroupThreadID, uint3 groupIdx : SV_GroupID\n"
|
||||
"#define AtomInc(x) InterlockedAdd(x, 1)\n"
|
||||
"#define AtomInc1(x, out) InterlockedAdd(x, 1, out)\n"
|
||||
"\n"
|
||||
"// takahiro end\n"
|
||||
"#define WG_SIZE 128\n"
|
||||
"#define NUM_PER_WI 4\n"
|
||||
"\n"
|
||||
"#define GET_GROUP_SIZE WG_SIZE\n"
|
||||
"\n"
|
||||
"typedef struct\n"
|
||||
"{\n"
|
||||
" u32 m_key;\n"
|
||||
" u32 m_value;\n"
|
||||
"}SortData;\n"
|
||||
"\n"
|
||||
"cbuffer SortCB : register( b0 )\n"
|
||||
"{\n"
|
||||
" u32 m_startBit;\n"
|
||||
" u32 m_numGroups;\n"
|
||||
" u32 m_padding[2];\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"StructuredBuffer<SortData> sortData : register( t0 );\n"
|
||||
"RWStructuredBuffer<u32> ldsHistogramOut : register( u0 );\n"
|
||||
"\n"
|
||||
"groupshared u32 ldsHistogram[16][256];\n"
|
||||
"\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void LocalCountKernel( DEFAULT_ARGS )\n"
|
||||
"{\n"
|
||||
" int lIdx = GET_LOCAL_IDX;\n"
|
||||
" int gIdx = GET_GLOBAL_IDX;\n"
|
||||
"\n"
|
||||
" for(int i=0; i<16; i++)\n"
|
||||
" {\n"
|
||||
" ldsHistogram[i][lIdx] = 0.f;\n"
|
||||
" ldsHistogram[i][lIdx+128] = 0.f;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" SortData datas[NUM_PER_WI];\n"
|
||||
" datas[0] = sortData[gIdx*NUM_PER_WI+0];\n"
|
||||
" datas[1] = sortData[gIdx*NUM_PER_WI+1];\n"
|
||||
" datas[2] = sortData[gIdx*NUM_PER_WI+2];\n"
|
||||
" datas[3] = sortData[gIdx*NUM_PER_WI+3];\n"
|
||||
"\n"
|
||||
" datas[0].m_key = (datas[0].m_key >> m_startBit) & 0xff;\n"
|
||||
" datas[1].m_key = (datas[1].m_key >> m_startBit) & 0xff;\n"
|
||||
" datas[2].m_key = (datas[2].m_key >> m_startBit) & 0xff;\n"
|
||||
" datas[3].m_key = (datas[3].m_key >> m_startBit) & 0xff;\n"
|
||||
"\n"
|
||||
" int tableIdx = lIdx%16;\n"
|
||||
"\n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[0].m_key]);\n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[1].m_key]);\n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[2].m_key]);\n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[3].m_key]);\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" u32 sum0, sum1;\n"
|
||||
" sum0 = sum1 = 0;\n"
|
||||
" for(int i=0; i<16; i++)\n"
|
||||
" {\n"
|
||||
" sum0 += ldsHistogram[i][lIdx];\n"
|
||||
" sum1 += ldsHistogram[i][lIdx+128];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" ldsHistogramOut[lIdx*m_numGroups+GET_GROUP_IDX] = sum0;\n"
|
||||
" ldsHistogramOut[(lIdx+128)*m_numGroups+GET_GROUP_IDX] = sum1;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"RWStructuredBuffer<SortData> sortDataOut : register( u0 );\n"
|
||||
"RWStructuredBuffer<u32> scannedHistogram : register( u1 );\n"
|
||||
"\n"
|
||||
"groupshared u32 ldsCurrentLocation[256];\n"
|
||||
"\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void ScatterKernel( DEFAULT_ARGS )\n"
|
||||
"{\n"
|
||||
" int lIdx = GET_LOCAL_IDX;\n"
|
||||
" int gIdx = GET_GLOBAL_IDX;\n"
|
||||
"\n"
|
||||
" {\n"
|
||||
" ldsCurrentLocation[lIdx] = scannedHistogram[lIdx*m_numGroups+GET_GROUP_IDX];\n"
|
||||
" ldsCurrentLocation[lIdx+128] = scannedHistogram[(lIdx+128)*m_numGroups+GET_GROUP_IDX];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" SortData datas[NUM_PER_WI];\n"
|
||||
" int keys[NUM_PER_WI];\n"
|
||||
" datas[0] = sortData[gIdx*NUM_PER_WI+0];\n"
|
||||
" datas[1] = sortData[gIdx*NUM_PER_WI+1];\n"
|
||||
" datas[2] = sortData[gIdx*NUM_PER_WI+2];\n"
|
||||
" datas[3] = sortData[gIdx*NUM_PER_WI+3];\n"
|
||||
"\n"
|
||||
" keys[0] = (datas[0].m_key >> m_startBit) & 0xff;\n"
|
||||
" keys[1] = (datas[1].m_key >> m_startBit) & 0xff;\n"
|
||||
" keys[2] = (datas[2].m_key >> m_startBit) & 0xff;\n"
|
||||
" keys[3] = (datas[3].m_key >> m_startBit) & 0xff;\n"
|
||||
"\n"
|
||||
" int dst[NUM_PER_WI];\n"
|
||||
" for(int i=0; i<WG_SIZE; i++)\n"
|
||||
"// for(int i=0; i<m_padding[0]; i++) // to reduce compile time\n"
|
||||
" {\n"
|
||||
" if( i==lIdx )\n"
|
||||
" {\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[0]], dst[0]);\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[1]], dst[1]);\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[2]], dst[2]);\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[3]], dst[3]);\n"
|
||||
" }\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" }\n"
|
||||
" sortDataOut[dst[0]] = datas[0];\n"
|
||||
" sortDataOut[dst[1]] = datas[1];\n"
|
||||
" sortDataOut[dst[2]] = datas[2];\n"
|
||||
" sortDataOut[dst[3]] = datas[3];\n"
|
||||
"}\n"
|
||||
"";
|
||||
@@ -0,0 +1,147 @@
|
||||
/*
|
||||
Bullet Continuous Collision Detection and Physics Library
|
||||
Copyright (c) 2011 Advanced Micro Devices, Inc. http://bulletphysics.org
|
||||
|
||||
This software is provided 'as-is', without any express or implied warranty.
|
||||
In no event will the authors be held liable for any damages arising from the use of this software.
|
||||
Permission is granted to anyone to use this software for any purpose,
|
||||
including commercial applications, and to alter it and redistribute it freely,
|
||||
subject to the following restrictions:
|
||||
|
||||
1. The origin of this software must not be misrepresented; you must not claim that you wrote the original software. If you use this software in a product, an acknowledgment in the product documentation would be appreciated but is not required.
|
||||
2. Altered source versions must be plainly marked as such, and must not be misrepresented as being the original software.
|
||||
3. This notice may not be removed or altered from any source distribution.
|
||||
*/
|
||||
//Author Takahiro Harada
|
||||
|
||||
#pragma OPENCL EXTENSION cl_amd_printf : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_local_int32_base_atomics : enable
|
||||
|
||||
typedef unsigned int u32;
|
||||
#define GET_GROUP_IDX get_group_id(0)
|
||||
#define GET_LOCAL_IDX get_local_id(0)
|
||||
#define GET_GLOBAL_IDX get_global_id(0)
|
||||
#define GET_GROUP_SIZE get_local_size(0)
|
||||
#define GROUP_LDS_BARRIER barrier(CLK_LOCAL_MEM_FENCE)
|
||||
#define AtomInc(x) atom_inc(&(x))
|
||||
#define AtomInc1(x, out) out = atom_inc(&(x))
|
||||
|
||||
|
||||
#define WG_SIZE 128
|
||||
#define NUM_PER_WI 4
|
||||
|
||||
|
||||
typedef struct
|
||||
{
|
||||
u32 m_key;
|
||||
u32 m_value;
|
||||
}SortData;
|
||||
|
||||
|
||||
typedef struct
|
||||
{
|
||||
u32 m_startBit;
|
||||
u32 m_numGroups;
|
||||
u32 m_padding[2];
|
||||
} ConstBuffer;
|
||||
|
||||
|
||||
__kernel
|
||||
__attribute__((reqd_work_group_size(WG_SIZE,1,1)))
|
||||
void LocalCountKernel(__global SortData* sortData,
|
||||
__global u32* ldsHistogramOut,
|
||||
ConstBuffer cb)
|
||||
{
|
||||
__local u32 ldsHistogram[16][256];
|
||||
|
||||
int lIdx = GET_LOCAL_IDX;
|
||||
int gIdx = GET_GLOBAL_IDX;
|
||||
|
||||
for(int i=0; i<16; i++)
|
||||
{
|
||||
ldsHistogram[i][lIdx] = 0.f;
|
||||
ldsHistogram[i][lIdx+128] = 0.f;
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
SortData datas[NUM_PER_WI];
|
||||
datas[0] = sortData[gIdx*NUM_PER_WI+0];
|
||||
datas[1] = sortData[gIdx*NUM_PER_WI+1];
|
||||
datas[2] = sortData[gIdx*NUM_PER_WI+2];
|
||||
datas[3] = sortData[gIdx*NUM_PER_WI+3];
|
||||
|
||||
datas[0].m_key = (datas[0].m_key >> cb.m_startBit) & 0xff;
|
||||
datas[1].m_key = (datas[1].m_key >> cb.m_startBit) & 0xff;
|
||||
datas[2].m_key = (datas[2].m_key >> cb.m_startBit) & 0xff;
|
||||
datas[3].m_key = (datas[3].m_key >> cb.m_startBit) & 0xff;
|
||||
|
||||
int tableIdx = lIdx%16;
|
||||
|
||||
AtomInc(ldsHistogram[tableIdx][datas[0].m_key]);
|
||||
AtomInc(ldsHistogram[tableIdx][datas[1].m_key]);
|
||||
AtomInc(ldsHistogram[tableIdx][datas[2].m_key]);
|
||||
AtomInc(ldsHistogram[tableIdx][datas[3].m_key]);
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
u32 sum0, sum1;
|
||||
sum0 = sum1 = 0;
|
||||
for(int i=0; i<16; i++)
|
||||
{
|
||||
sum0 += ldsHistogram[i][lIdx];
|
||||
sum1 += ldsHistogram[i][lIdx+128];
|
||||
}
|
||||
|
||||
ldsHistogramOut[lIdx*cb.m_numGroups+GET_GROUP_IDX] = sum0;
|
||||
ldsHistogramOut[(lIdx+128)*cb.m_numGroups+GET_GROUP_IDX] = sum1;
|
||||
}
|
||||
|
||||
__kernel
|
||||
__attribute__((reqd_work_group_size(WG_SIZE,1,1)))
|
||||
void ScatterKernel(__global SortData* sortData,
|
||||
__global SortData* sortDataOut,
|
||||
__global u32* scannedHistogram,
|
||||
ConstBuffer cb)
|
||||
{
|
||||
__local u32 ldsCurrentLocation[256];
|
||||
|
||||
int lIdx = GET_LOCAL_IDX;
|
||||
int gIdx = GET_GLOBAL_IDX;
|
||||
|
||||
{
|
||||
ldsCurrentLocation[lIdx] = scannedHistogram[lIdx*cb.m_numGroups+GET_GROUP_IDX];
|
||||
ldsCurrentLocation[lIdx+128] = scannedHistogram[(lIdx+128)*cb.m_numGroups+GET_GROUP_IDX];
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
SortData datas[NUM_PER_WI];
|
||||
int keys[NUM_PER_WI];
|
||||
datas[0] = sortData[gIdx*NUM_PER_WI+0];
|
||||
datas[1] = sortData[gIdx*NUM_PER_WI+1];
|
||||
datas[2] = sortData[gIdx*NUM_PER_WI+2];
|
||||
datas[3] = sortData[gIdx*NUM_PER_WI+3];
|
||||
|
||||
keys[0] = (datas[0].m_key >> cb.m_startBit) & 0xff;
|
||||
keys[1] = (datas[1].m_key >> cb.m_startBit) & 0xff;
|
||||
keys[2] = (datas[2].m_key >> cb.m_startBit) & 0xff;
|
||||
keys[3] = (datas[3].m_key >> cb.m_startBit) & 0xff;
|
||||
|
||||
int dst[NUM_PER_WI];
|
||||
for(int i=0; i<WG_SIZE; i++)
|
||||
{
|
||||
if( i==lIdx )
|
||||
{
|
||||
AtomInc1(ldsCurrentLocation[keys[0]], dst[0]);
|
||||
AtomInc1(ldsCurrentLocation[keys[1]], dst[1]);
|
||||
AtomInc1(ldsCurrentLocation[keys[2]], dst[2]);
|
||||
AtomInc1(ldsCurrentLocation[keys[3]], dst[3]);
|
||||
}
|
||||
GROUP_LDS_BARRIER;
|
||||
}
|
||||
sortDataOut[dst[0]] = datas[0];
|
||||
sortDataOut[dst[1]] = datas[1];
|
||||
sortDataOut[dst[2]] = datas[2];
|
||||
sortDataOut[dst[3]] = datas[3];
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
/*
|
||||
2011 Takahiro Harada
|
||||
*/
|
||||
|
||||
typedef uint u32;
|
||||
|
||||
#define GET_GROUP_IDX groupIdx.x
|
||||
#define GET_LOCAL_IDX localIdx.x
|
||||
#define GET_GLOBAL_IDX globalIdx.x
|
||||
#define GROUP_LDS_BARRIER GroupMemoryBarrierWithGroupSync()
|
||||
#define DEFAULT_ARGS uint3 globalIdx : SV_DispatchThreadID, uint3 localIdx : SV_GroupThreadID, uint3 groupIdx : SV_GroupID
|
||||
#define AtomInc(x) InterlockedAdd(x, 1)
|
||||
#define AtomInc1(x, out) InterlockedAdd(x, 1, out)
|
||||
|
||||
// takahiro end
|
||||
#define WG_SIZE 128
|
||||
#define NUM_PER_WI 4
|
||||
|
||||
#define GET_GROUP_SIZE WG_SIZE
|
||||
|
||||
typedef struct
|
||||
{
|
||||
u32 m_key;
|
||||
u32 m_value;
|
||||
}SortData;
|
||||
|
||||
cbuffer SortCB : register( b0 )
|
||||
{
|
||||
u32 m_startBit;
|
||||
u32 m_numGroups;
|
||||
u32 m_padding[2];
|
||||
};
|
||||
|
||||
StructuredBuffer<SortData> sortData : register( t0 );
|
||||
RWStructuredBuffer<u32> ldsHistogramOut : register( u0 );
|
||||
|
||||
groupshared u32 ldsHistogram[16][256];
|
||||
|
||||
[numthreads(WG_SIZE, 1, 1)]
|
||||
void LocalCountKernel( DEFAULT_ARGS )
|
||||
{
|
||||
int lIdx = GET_LOCAL_IDX;
|
||||
int gIdx = GET_GLOBAL_IDX;
|
||||
|
||||
for(int i=0; i<16; i++)
|
||||
{
|
||||
ldsHistogram[i][lIdx] = 0.f;
|
||||
ldsHistogram[i][lIdx+128] = 0.f;
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
SortData datas[NUM_PER_WI];
|
||||
datas[0] = sortData[gIdx*NUM_PER_WI+0];
|
||||
datas[1] = sortData[gIdx*NUM_PER_WI+1];
|
||||
datas[2] = sortData[gIdx*NUM_PER_WI+2];
|
||||
datas[3] = sortData[gIdx*NUM_PER_WI+3];
|
||||
|
||||
datas[0].m_key = (datas[0].m_key >> m_startBit) & 0xff;
|
||||
datas[1].m_key = (datas[1].m_key >> m_startBit) & 0xff;
|
||||
datas[2].m_key = (datas[2].m_key >> m_startBit) & 0xff;
|
||||
datas[3].m_key = (datas[3].m_key >> m_startBit) & 0xff;
|
||||
|
||||
int tableIdx = lIdx%16;
|
||||
|
||||
AtomInc(ldsHistogram[tableIdx][datas[0].m_key]);
|
||||
AtomInc(ldsHistogram[tableIdx][datas[1].m_key]);
|
||||
AtomInc(ldsHistogram[tableIdx][datas[2].m_key]);
|
||||
AtomInc(ldsHistogram[tableIdx][datas[3].m_key]);
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
u32 sum0, sum1;
|
||||
sum0 = sum1 = 0;
|
||||
for(int i=0; i<16; i++)
|
||||
{
|
||||
sum0 += ldsHistogram[i][lIdx];
|
||||
sum1 += ldsHistogram[i][lIdx+128];
|
||||
}
|
||||
|
||||
ldsHistogramOut[lIdx*m_numGroups+GET_GROUP_IDX] = sum0;
|
||||
ldsHistogramOut[(lIdx+128)*m_numGroups+GET_GROUP_IDX] = sum1;
|
||||
}
|
||||
|
||||
|
||||
RWStructuredBuffer<SortData> sortDataOut : register( u0 );
|
||||
RWStructuredBuffer<u32> scannedHistogram : register( u1 );
|
||||
|
||||
groupshared u32 ldsCurrentLocation[256];
|
||||
|
||||
[numthreads(WG_SIZE, 1, 1)]
|
||||
void ScatterKernel( DEFAULT_ARGS )
|
||||
{
|
||||
int lIdx = GET_LOCAL_IDX;
|
||||
int gIdx = GET_GLOBAL_IDX;
|
||||
|
||||
{
|
||||
ldsCurrentLocation[lIdx] = scannedHistogram[lIdx*m_numGroups+GET_GROUP_IDX];
|
||||
ldsCurrentLocation[lIdx+128] = scannedHistogram[(lIdx+128)*m_numGroups+GET_GROUP_IDX];
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
SortData datas[NUM_PER_WI];
|
||||
int keys[NUM_PER_WI];
|
||||
datas[0] = sortData[gIdx*NUM_PER_WI+0];
|
||||
datas[1] = sortData[gIdx*NUM_PER_WI+1];
|
||||
datas[2] = sortData[gIdx*NUM_PER_WI+2];
|
||||
datas[3] = sortData[gIdx*NUM_PER_WI+3];
|
||||
|
||||
keys[0] = (datas[0].m_key >> m_startBit) & 0xff;
|
||||
keys[1] = (datas[1].m_key >> m_startBit) & 0xff;
|
||||
keys[2] = (datas[2].m_key >> m_startBit) & 0xff;
|
||||
keys[3] = (datas[3].m_key >> m_startBit) & 0xff;
|
||||
|
||||
int dst[NUM_PER_WI];
|
||||
for(int i=0; i<WG_SIZE; i++)
|
||||
// for(int i=0; i<m_padding[0]; i++) // to reduce compile time
|
||||
{
|
||||
if( i==lIdx )
|
||||
{
|
||||
AtomInc1(ldsCurrentLocation[keys[0]], dst[0]);
|
||||
AtomInc1(ldsCurrentLocation[keys[1]], dst[1]);
|
||||
AtomInc1(ldsCurrentLocation[keys[2]], dst[2]);
|
||||
AtomInc1(ldsCurrentLocation[keys[3]], dst[3]);
|
||||
}
|
||||
GROUP_LDS_BARRIER;
|
||||
}
|
||||
sortDataOut[dst[0]] = datas[0];
|
||||
sortDataOut[dst[1]] = datas[1];
|
||||
sortDataOut[dst[2]] = datas[2];
|
||||
sortDataOut[dst[3]] = datas[3];
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
static const char* radixSortSimpleKernelsCL= \
|
||||
"/*\n"
|
||||
"Bullet Continuous Collision Detection and Physics Library\n"
|
||||
"Copyright (c) 2011 Advanced Micro Devices, Inc. http://bulletphysics.org\n"
|
||||
"\n"
|
||||
"This software is provided 'as-is', without any express or implied warranty.\n"
|
||||
"In no event will the authors be held liable for any damages arising from the use of this software.\n"
|
||||
"Permission is granted to anyone to use this software for any purpose, \n"
|
||||
"including commercial applications, and to alter it and redistribute it freely, \n"
|
||||
"subject to the following restrictions:\n"
|
||||
"\n"
|
||||
"1. The origin of this software must not be misrepresented; you must not claim that you wrote the original software. If you use this software in a product, an acknowledgment in the product documentation would be appreciated but is not required.\n"
|
||||
"2. Altered source versions must be plainly marked as such, and must not be misrepresented as being the original software.\n"
|
||||
"3. This notice may not be removed or altered from any source distribution.\n"
|
||||
"*/\n"
|
||||
"//Author Takahiro Harada\n"
|
||||
"\n"
|
||||
"#pragma OPENCL EXTENSION cl_amd_printf : enable\n"
|
||||
"#pragma OPENCL EXTENSION cl_khr_local_int32_base_atomics : enable\n"
|
||||
"\n"
|
||||
"typedef unsigned int u32;\n"
|
||||
"#define GET_GROUP_IDX get_group_id(0)\n"
|
||||
"#define GET_LOCAL_IDX get_local_id(0)\n"
|
||||
"#define GET_GLOBAL_IDX get_global_id(0)\n"
|
||||
"#define GET_GROUP_SIZE get_local_size(0)\n"
|
||||
"#define GROUP_LDS_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n"
|
||||
"#define AtomInc(x) atom_inc(&(x))\n"
|
||||
"#define AtomInc1(x, out) out = atom_inc(&(x))\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"#define WG_SIZE 128\n"
|
||||
"#define NUM_PER_WI 4\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"typedef struct\n"
|
||||
"{\n"
|
||||
" u32 m_key; \n"
|
||||
" u32 m_value;\n"
|
||||
"}SortData;\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"typedef struct\n"
|
||||
"{\n"
|
||||
" u32 m_startBit;\n"
|
||||
" u32 m_numGroups;\n"
|
||||
" u32 m_padding[2];\n"
|
||||
"} ConstBuffer;\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"__kernel\n"
|
||||
"__attribute__((reqd_work_group_size(WG_SIZE,1,1)))\n"
|
||||
"void LocalCountKernel(__global SortData* sortData, \n"
|
||||
" __global u32* ldsHistogramOut,\n"
|
||||
" ConstBuffer cb)\n"
|
||||
"{\n"
|
||||
" __local u32 ldsHistogram[16][256];\n"
|
||||
"\n"
|
||||
" int lIdx = GET_LOCAL_IDX;\n"
|
||||
" int gIdx = GET_GLOBAL_IDX;\n"
|
||||
" \n"
|
||||
" for(int i=0; i<16; i++)\n"
|
||||
" {\n"
|
||||
" ldsHistogram[i][lIdx] = 0.f;\n"
|
||||
" ldsHistogram[i][lIdx+128] = 0.f;\n"
|
||||
" }\n"
|
||||
" \n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
" SortData datas[NUM_PER_WI];\n"
|
||||
" datas[0] = sortData[gIdx*NUM_PER_WI+0];\n"
|
||||
" datas[1] = sortData[gIdx*NUM_PER_WI+1];\n"
|
||||
" datas[2] = sortData[gIdx*NUM_PER_WI+2];\n"
|
||||
" datas[3] = sortData[gIdx*NUM_PER_WI+3];\n"
|
||||
"\n"
|
||||
" datas[0].m_key = (datas[0].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
" datas[1].m_key = (datas[1].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
" datas[2].m_key = (datas[2].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
" datas[3].m_key = (datas[3].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
"\n"
|
||||
" int tableIdx = lIdx%16;\n"
|
||||
" \n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[0].m_key]);\n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[1].m_key]);\n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[2].m_key]);\n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[3].m_key]);\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
" u32 sum0, sum1;\n"
|
||||
" sum0 = sum1 = 0;\n"
|
||||
" for(int i=0; i<16; i++)\n"
|
||||
" {\n"
|
||||
" sum0 += ldsHistogram[i][lIdx];\n"
|
||||
" sum1 += ldsHistogram[i][lIdx+128];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" ldsHistogramOut[lIdx*cb.m_numGroups+GET_GROUP_IDX] = sum0;\n"
|
||||
" ldsHistogramOut[(lIdx+128)*cb.m_numGroups+GET_GROUP_IDX] = sum1;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"__kernel\n"
|
||||
"__attribute__((reqd_work_group_size(WG_SIZE,1,1)))\n"
|
||||
"void ScatterKernel(__global SortData* sortData,\n"
|
||||
" __global SortData* sortDataOut,\n"
|
||||
" __global u32* scannedHistogram, \n"
|
||||
" ConstBuffer cb)\n"
|
||||
"{\n"
|
||||
" __local u32 ldsCurrentLocation[256];\n"
|
||||
"\n"
|
||||
" int lIdx = GET_LOCAL_IDX;\n"
|
||||
" int gIdx = GET_GLOBAL_IDX;\n"
|
||||
" \n"
|
||||
" {\n"
|
||||
" ldsCurrentLocation[lIdx] = scannedHistogram[lIdx*cb.m_numGroups+GET_GROUP_IDX];\n"
|
||||
" ldsCurrentLocation[lIdx+128] = scannedHistogram[(lIdx+128)*cb.m_numGroups+GET_GROUP_IDX];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
" SortData datas[NUM_PER_WI];\n"
|
||||
" int keys[NUM_PER_WI];\n"
|
||||
" datas[0] = sortData[gIdx*NUM_PER_WI+0];\n"
|
||||
" datas[1] = sortData[gIdx*NUM_PER_WI+1];\n"
|
||||
" datas[2] = sortData[gIdx*NUM_PER_WI+2];\n"
|
||||
" datas[3] = sortData[gIdx*NUM_PER_WI+3];\n"
|
||||
"\n"
|
||||
" keys[0] = (datas[0].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
" keys[1] = (datas[1].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
" keys[2] = (datas[2].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
" keys[3] = (datas[3].m_key >> cb.m_startBit) & 0xff;\n"
|
||||
"\n"
|
||||
" int dst[NUM_PER_WI];\n"
|
||||
" for(int i=0; i<WG_SIZE; i++)\n"
|
||||
" {\n"
|
||||
" if( i==lIdx )\n"
|
||||
" {\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[0]], dst[0]);\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[1]], dst[1]);\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[2]], dst[2]);\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[3]], dst[3]);\n"
|
||||
" }\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" }\n"
|
||||
" sortDataOut[dst[0]] = datas[0];\n"
|
||||
" sortDataOut[dst[1]] = datas[1];\n"
|
||||
" sortDataOut[dst[2]] = datas[2];\n"
|
||||
" sortDataOut[dst[3]] = datas[3];\n"
|
||||
"}\n"
|
||||
;
|
||||
@@ -0,0 +1,135 @@
|
||||
static const char* radixSortSimpleKernelsDX11= \
|
||||
"/*\n"
|
||||
" 2011 Takahiro Harada\n"
|
||||
"*/\n"
|
||||
"\n"
|
||||
"typedef uint u32;\n"
|
||||
"\n"
|
||||
"#define GET_GROUP_IDX groupIdx.x\n"
|
||||
"#define GET_LOCAL_IDX localIdx.x\n"
|
||||
"#define GET_GLOBAL_IDX globalIdx.x\n"
|
||||
"#define GROUP_LDS_BARRIER GroupMemoryBarrierWithGroupSync()\n"
|
||||
"#define DEFAULT_ARGS uint3 globalIdx : SV_DispatchThreadID, uint3 localIdx : SV_GroupThreadID, uint3 groupIdx : SV_GroupID\n"
|
||||
"#define AtomInc(x) InterlockedAdd(x, 1)\n"
|
||||
"#define AtomInc1(x, out) InterlockedAdd(x, 1, out)\n"
|
||||
"\n"
|
||||
"// takahiro end\n"
|
||||
"#define WG_SIZE 128\n"
|
||||
"#define NUM_PER_WI 4\n"
|
||||
"\n"
|
||||
"#define GET_GROUP_SIZE WG_SIZE\n"
|
||||
"\n"
|
||||
"typedef struct\n"
|
||||
"{\n"
|
||||
" u32 m_key; \n"
|
||||
" u32 m_value;\n"
|
||||
"}SortData;\n"
|
||||
"\n"
|
||||
"cbuffer SortCB : register( b0 )\n"
|
||||
"{\n"
|
||||
" u32 m_startBit;\n"
|
||||
" u32 m_numGroups;\n"
|
||||
" u32 m_padding[2];\n"
|
||||
"};\n"
|
||||
" \n"
|
||||
"StructuredBuffer<SortData> sortData : register( t0 );\n"
|
||||
"RWStructuredBuffer<u32> ldsHistogramOut : register( u0 );\n"
|
||||
"\n"
|
||||
"groupshared u32 ldsHistogram[16][256];\n"
|
||||
"\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void LocalCountKernel( DEFAULT_ARGS )\n"
|
||||
"{\n"
|
||||
" int lIdx = GET_LOCAL_IDX;\n"
|
||||
" int gIdx = GET_GLOBAL_IDX;\n"
|
||||
" \n"
|
||||
" for(int i=0; i<16; i++)\n"
|
||||
" {\n"
|
||||
" ldsHistogram[i][lIdx] = 0.f;\n"
|
||||
" ldsHistogram[i][lIdx+128] = 0.f;\n"
|
||||
" }\n"
|
||||
" \n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
" SortData datas[NUM_PER_WI];\n"
|
||||
" datas[0] = sortData[gIdx*NUM_PER_WI+0];\n"
|
||||
" datas[1] = sortData[gIdx*NUM_PER_WI+1];\n"
|
||||
" datas[2] = sortData[gIdx*NUM_PER_WI+2];\n"
|
||||
" datas[3] = sortData[gIdx*NUM_PER_WI+3];\n"
|
||||
"\n"
|
||||
" datas[0].m_key = (datas[0].m_key >> m_startBit) & 0xff;\n"
|
||||
" datas[1].m_key = (datas[1].m_key >> m_startBit) & 0xff;\n"
|
||||
" datas[2].m_key = (datas[2].m_key >> m_startBit) & 0xff;\n"
|
||||
" datas[3].m_key = (datas[3].m_key >> m_startBit) & 0xff;\n"
|
||||
"\n"
|
||||
" int tableIdx = lIdx%16;\n"
|
||||
" \n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[0].m_key]);\n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[1].m_key]);\n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[2].m_key]);\n"
|
||||
" AtomInc(ldsHistogram[tableIdx][datas[3].m_key]);\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
" u32 sum0, sum1;\n"
|
||||
" sum0 = sum1 = 0;\n"
|
||||
" for(int i=0; i<16; i++)\n"
|
||||
" {\n"
|
||||
" sum0 += ldsHistogram[i][lIdx];\n"
|
||||
" sum1 += ldsHistogram[i][lIdx+128];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" ldsHistogramOut[lIdx*m_numGroups+GET_GROUP_IDX] = sum0;\n"
|
||||
" ldsHistogramOut[(lIdx+128)*m_numGroups+GET_GROUP_IDX] = sum1;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"RWStructuredBuffer<SortData> sortDataOut : register( u0 );\n"
|
||||
"RWStructuredBuffer<u32> scannedHistogram : register( u1 );\n"
|
||||
"\n"
|
||||
"groupshared u32 ldsCurrentLocation[256];\n"
|
||||
"\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void ScatterKernel( DEFAULT_ARGS )\n"
|
||||
"{\n"
|
||||
" int lIdx = GET_LOCAL_IDX;\n"
|
||||
" int gIdx = GET_GLOBAL_IDX;\n"
|
||||
" \n"
|
||||
" {\n"
|
||||
" ldsCurrentLocation[lIdx] = scannedHistogram[lIdx*m_numGroups+GET_GROUP_IDX];\n"
|
||||
" ldsCurrentLocation[lIdx+128] = scannedHistogram[(lIdx+128)*m_numGroups+GET_GROUP_IDX];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
" SortData datas[NUM_PER_WI];\n"
|
||||
" int keys[NUM_PER_WI];\n"
|
||||
" datas[0] = sortData[gIdx*NUM_PER_WI+0];\n"
|
||||
" datas[1] = sortData[gIdx*NUM_PER_WI+1];\n"
|
||||
" datas[2] = sortData[gIdx*NUM_PER_WI+2];\n"
|
||||
" datas[3] = sortData[gIdx*NUM_PER_WI+3];\n"
|
||||
"\n"
|
||||
" keys[0] = (datas[0].m_key >> m_startBit) & 0xff;\n"
|
||||
" keys[1] = (datas[1].m_key >> m_startBit) & 0xff;\n"
|
||||
" keys[2] = (datas[2].m_key >> m_startBit) & 0xff;\n"
|
||||
" keys[3] = (datas[3].m_key >> m_startBit) & 0xff;\n"
|
||||
"\n"
|
||||
" int dst[NUM_PER_WI];\n"
|
||||
" for(int i=0; i<WG_SIZE; i++)\n"
|
||||
"// for(int i=0; i<m_padding[0]; i++) // to reduce compile time\n"
|
||||
" {\n"
|
||||
" if( i==lIdx )\n"
|
||||
" {\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[0]], dst[0]);\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[1]], dst[1]);\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[2]], dst[2]);\n"
|
||||
" AtomInc1(ldsCurrentLocation[keys[3]], dst[3]);\n"
|
||||
" }\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" }\n"
|
||||
" sortDataOut[dst[0]] = datas[0];\n"
|
||||
" sortDataOut[dst[1]] = datas[1];\n"
|
||||
" sortDataOut[dst[2]] = datas[2];\n"
|
||||
" sortDataOut[dst[3]] = datas[3];\n"
|
||||
"}\n"
|
||||
;
|
||||
@@ -0,0 +1,177 @@
|
||||
/*
|
||||
2011 Takahiro Harada
|
||||
*/
|
||||
|
||||
#define PATH "..\\..\\opencl\\primitives\\AdlPrimitives\\Sort\\RadixSortStandardKernels"
|
||||
#define KERNEL0 "LocalSortKernel"
|
||||
#define KERNEL1 "ScatterKernel"
|
||||
#define KERNEL2 "CopyKernel"
|
||||
|
||||
#include <AdlPrimitives/Sort/RadixSortStandardKernelsCL.h>
|
||||
#include <AdlPrimitives/Sort/RadixSortStandardKernelsDX11.h>
|
||||
|
||||
template<DeviceType type>
|
||||
class RadixSortStandard : public RadixSortBase
|
||||
{
|
||||
public:
|
||||
typedef Launcher::BufferInfo BufferInfo;
|
||||
|
||||
enum
|
||||
{
|
||||
WG_SIZE = 128,
|
||||
NUM_PER_WI = 4,
|
||||
|
||||
BITS_PER_PASS = 4,
|
||||
};
|
||||
|
||||
struct Data : public RadixSort<type>::Data
|
||||
{
|
||||
Kernel* m_localSortKernel;
|
||||
Kernel* m_scatterKernel;
|
||||
Kernel* m_copyKernel;
|
||||
|
||||
Buffer<u32>* m_workBuffer0;
|
||||
Buffer<u32>* m_workBuffer1;
|
||||
Buffer<u32>* m_workBuffer2;
|
||||
Buffer<SortData>* m_workBuffer3;
|
||||
Buffer<int4>* m_constBuffer[32/BITS_PER_PASS];
|
||||
};
|
||||
|
||||
|
||||
static
|
||||
Data* allocate(const Device* deviceData, int maxSize, Option option = SORT_NORMAL);
|
||||
|
||||
static
|
||||
void deallocate(void* data);
|
||||
|
||||
static
|
||||
void execute(void* data, Buffer<SortData>& inout, int n, int sortBits);
|
||||
};
|
||||
|
||||
template<DeviceType type>
|
||||
typename RadixSortStandard<type>::Data* RadixSortStandard<type>::allocate(const Device* deviceData, int maxSize, Option option)
|
||||
{
|
||||
ADLASSERT( type == deviceData->m_type );
|
||||
|
||||
u32 maxNumGroups = (maxSize+WG_SIZE*NUM_PER_WI-1)/(WG_SIZE*NUM_PER_WI);
|
||||
|
||||
const char* src[] =
|
||||
#if defined(ADL_LOAD_KERNEL_FROM_STRING)
|
||||
{radixSortStandardKernelsCL,radixSortStandardKernelsDX11};
|
||||
// ADLASSERT(0);
|
||||
#else
|
||||
{0,0};
|
||||
#endif
|
||||
|
||||
Data* data = new Data;
|
||||
data->m_option = option;
|
||||
data->m_deviceData = deviceData;
|
||||
|
||||
data->m_localSortKernel = deviceData->getKernel( PATH, KERNEL0, 0, src[type] );
|
||||
data->m_scatterKernel = deviceData->getKernel( PATH, KERNEL1, 0, src[type] );
|
||||
data->m_copyKernel = deviceData->getKernel( PATH, KERNEL2, 0, src[type] );
|
||||
|
||||
// is this correct?
|
||||
data->m_scanData = PrefixScan<type>::allocate( deviceData, maxNumGroups*(1<<BITS_PER_PASS) );
|
||||
|
||||
data->m_workBuffer0 = new Buffer<u32>( deviceData, maxNumGroups*(1<<BITS_PER_PASS) );
|
||||
data->m_workBuffer1 = new Buffer<u32>( deviceData, maxNumGroups*(1<<BITS_PER_PASS) );
|
||||
data->m_workBuffer2 = new Buffer<u32>( deviceData, maxNumGroups*(1<<BITS_PER_PASS) );
|
||||
data->m_workBuffer3 = new Buffer<SortData>( deviceData, maxSize );
|
||||
for(int i=0; i<32/BITS_PER_PASS; i++)
|
||||
data->m_constBuffer[i] = new Buffer<int4>( deviceData, 1, BufferBase::BUFFER_CONST );
|
||||
data->m_maxSize = maxSize;
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
template<DeviceType type>
|
||||
void RadixSortStandard<type>::deallocate(void* rawData)
|
||||
{
|
||||
Data* data = (Data*)rawData;
|
||||
|
||||
delete data->m_workBuffer0;
|
||||
delete data->m_workBuffer1;
|
||||
delete data->m_workBuffer2;
|
||||
delete data->m_workBuffer3;
|
||||
for(int i=0; i<32/BITS_PER_PASS; i++)
|
||||
delete data->m_constBuffer[i];
|
||||
|
||||
PrefixScan<type>::deallocate( data->m_scanData );
|
||||
|
||||
delete data;
|
||||
}
|
||||
|
||||
template<DeviceType type>
|
||||
void RadixSortStandard<type>::execute(void* rawData, Buffer<SortData>& inout, int n, int sortBits)
|
||||
{
|
||||
Data* data = (Data*)rawData;
|
||||
|
||||
ADLASSERT( n%512 == 0 );
|
||||
ADLASSERT( n <= data->m_maxSize );
|
||||
ADLASSERT( NUM_PER_WI == 4 );
|
||||
|
||||
Buffer<SortData>* src = BufferUtils::map<type, true>( data->m_deviceData, &inout );
|
||||
Buffer<SortData>* dst = data->m_workBuffer3;
|
||||
|
||||
const Device* deviceData = data->m_deviceData;
|
||||
|
||||
int numGroups = (n+WG_SIZE*NUM_PER_WI-1)/(WG_SIZE*NUM_PER_WI);
|
||||
|
||||
int4 constBuffer;
|
||||
|
||||
int iPass = 0;
|
||||
for(int startBit=0; startBit<sortBits; startBit+=BITS_PER_PASS, iPass++)
|
||||
{
|
||||
constBuffer.x = startBit;
|
||||
constBuffer.y = numGroups;
|
||||
constBuffer.z = WG_SIZE;
|
||||
|
||||
{
|
||||
BufferInfo bInfo[] = { BufferInfo( src ), BufferInfo( data->m_workBuffer0 ), BufferInfo( data->m_workBuffer1 ) };
|
||||
|
||||
Launcher launcher( deviceData, data->m_localSortKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[iPass], constBuffer );
|
||||
launcher.launch1D( WG_SIZE*numGroups, WG_SIZE );
|
||||
}
|
||||
|
||||
PrefixScan<type>::execute( data->m_scanData, *data->m_workBuffer0, *data->m_workBuffer2, numGroups*(1<<BITS_PER_PASS) );
|
||||
|
||||
{
|
||||
BufferInfo bInfo[] = { BufferInfo( src, true ), BufferInfo( data->m_workBuffer2, true ), BufferInfo( data->m_workBuffer1, true ),
|
||||
BufferInfo( dst ) };
|
||||
|
||||
Launcher launcher( deviceData, data->m_scatterKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[iPass], constBuffer );
|
||||
launcher.launch1D( WG_SIZE*numGroups, WG_SIZE );
|
||||
}
|
||||
|
||||
if(0)
|
||||
{
|
||||
BufferInfo bInfo[] = { BufferInfo( dst, true ), BufferInfo( src ) };
|
||||
|
||||
Launcher launcher( deviceData, data->m_copyKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.launch1D( n, WG_SIZE );
|
||||
}
|
||||
swap2( src, dst );
|
||||
}
|
||||
|
||||
if( src != &inout )
|
||||
{
|
||||
BufferInfo bInfo[] = { BufferInfo( src, true ), BufferInfo( dst ) };
|
||||
|
||||
Launcher launcher( deviceData, data->m_copyKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.launch1D( n, WG_SIZE );
|
||||
}
|
||||
|
||||
BufferUtils::unmap<true>( src, &inout );
|
||||
}
|
||||
|
||||
#undef PATH
|
||||
#undef KERNEL0
|
||||
#undef KERNEL1
|
||||
#undef KERNEL2
|
||||
@@ -0,0 +1,345 @@
|
||||
/*
|
||||
Bullet Continuous Collision Detection and Physics Library
|
||||
Copyright (c) 2011 Advanced Micro Devices, Inc. http://bulletphysics.org
|
||||
|
||||
This software is provided 'as-is', without any express or implied warranty.
|
||||
In no event will the authors be held liable for any damages arising from the use of this software.
|
||||
Permission is granted to anyone to use this software for any purpose,
|
||||
including commercial applications, and to alter it and redistribute it freely,
|
||||
subject to the following restrictions:
|
||||
|
||||
1. The origin of this software must not be misrepresented; you must not claim that you wrote the original software. If you use this software in a product, an acknowledgment in the product documentation would be appreciated but is not required.
|
||||
2. Altered source versions must be plainly marked as such, and must not be misrepresented as being the original software.
|
||||
3. This notice may not be removed or altered from any source distribution.
|
||||
*/
|
||||
//Author Takahiro Harada
|
||||
|
||||
|
||||
#pragma OPENCL EXTENSION cl_amd_printf : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_local_int32_base_atomics : enable
|
||||
|
||||
typedef unsigned int u32;
|
||||
#define GET_GROUP_IDX get_group_id(0)
|
||||
#define GET_LOCAL_IDX get_local_id(0)
|
||||
#define GET_GLOBAL_IDX get_global_id(0)
|
||||
#define GET_GROUP_SIZE get_local_size(0)
|
||||
#define GROUP_LDS_BARRIER barrier(CLK_LOCAL_MEM_FENCE)
|
||||
#define GROUP_MEM_FENCE mem_fence(CLK_LOCAL_MEM_FENCE)
|
||||
#define AtomInc(x) atom_inc(&(x))
|
||||
#define AtomInc1(x, out) out = atom_inc(&(x))
|
||||
|
||||
#define make_uint4 (uint4)
|
||||
#define make_uint2 (uint2)
|
||||
|
||||
#define SELECT_UINT4( b, a, condition ) select( b,a,condition )
|
||||
|
||||
#define WG_SIZE 128
|
||||
#define NUM_PER_WI 4
|
||||
|
||||
|
||||
typedef struct
|
||||
{
|
||||
u32 m_key;
|
||||
u32 m_value;
|
||||
}SortData;
|
||||
|
||||
|
||||
typedef struct
|
||||
{
|
||||
u32 m_startBit;
|
||||
u32 m_numGroups;
|
||||
u32 m_padding[2];
|
||||
} ConstBuffer;
|
||||
|
||||
#define BITS_PER_PASS 4
|
||||
|
||||
|
||||
|
||||
uint4 prefixScanVector( uint4 data )
|
||||
{
|
||||
data.y += data.x;
|
||||
data.w += data.z;
|
||||
data.z += data.y;
|
||||
data.w += data.y;
|
||||
return data;
|
||||
}
|
||||
|
||||
uint prefixScanVectorEx( uint4* data )
|
||||
{
|
||||
uint4 backup = data[0];
|
||||
data[0].y += data[0].x;
|
||||
data[0].w += data[0].z;
|
||||
data[0].z += data[0].y;
|
||||
data[0].w += data[0].y;
|
||||
uint sum = data[0].w;
|
||||
*data -= backup;
|
||||
return sum;
|
||||
}
|
||||
|
||||
uint4 localPrefixSum128V( uint4 pData, uint lIdx, uint* totalSum, __local u32 sorterSharedMemory[] )
|
||||
{
|
||||
{ // Set data
|
||||
sorterSharedMemory[lIdx] = 0;
|
||||
sorterSharedMemory[lIdx+WG_SIZE] = prefixScanVectorEx( &pData );
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
{ // Prefix sum
|
||||
int idx = 2*lIdx + (WG_SIZE+1);
|
||||
if( lIdx < 64 )
|
||||
{
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-1];
|
||||
GROUP_MEM_FENCE;
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-2];
|
||||
GROUP_MEM_FENCE;
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-4];
|
||||
GROUP_MEM_FENCE;
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-8];
|
||||
GROUP_MEM_FENCE;
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-16];
|
||||
GROUP_MEM_FENCE;
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-32];
|
||||
GROUP_MEM_FENCE;
|
||||
sorterSharedMemory[idx] += sorterSharedMemory[idx-64];
|
||||
GROUP_MEM_FENCE;
|
||||
|
||||
sorterSharedMemory[idx-1] += sorterSharedMemory[idx-2];
|
||||
GROUP_MEM_FENCE;
|
||||
}
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
*totalSum = sorterSharedMemory[WG_SIZE*2-1];
|
||||
uint addValue = sorterSharedMemory[lIdx+127];
|
||||
return pData + make_uint4(addValue, addValue, addValue, addValue);
|
||||
}
|
||||
|
||||
|
||||
void generateHistogram(u32 lIdx, u32 wgIdx,
|
||||
uint4 sortedData,
|
||||
__local u32 *histogram)
|
||||
{
|
||||
if( lIdx < (1<<BITS_PER_PASS) )
|
||||
{
|
||||
histogram[lIdx] = 0;
|
||||
}
|
||||
|
||||
int mask = ((1<<BITS_PER_PASS)-1);
|
||||
uint4 keys = make_uint4( (sortedData.x)&mask, (sortedData.y)&mask, (sortedData.z)&mask, (sortedData.w)&mask );
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
AtomInc( histogram[keys.x] );
|
||||
AtomInc( histogram[keys.y] );
|
||||
AtomInc( histogram[keys.z] );
|
||||
AtomInc( histogram[keys.w] );
|
||||
}
|
||||
|
||||
//
|
||||
//
|
||||
//
|
||||
|
||||
__kernel
|
||||
__attribute__((reqd_work_group_size(WG_SIZE,1,1)))
|
||||
void LocalSortKernel(__global SortData* sortDataIn,
|
||||
__global u32* ldsHistogramOut0,
|
||||
__global u32* ldsHistogramOut1,
|
||||
ConstBuffer cb)
|
||||
{
|
||||
|
||||
__local u32 ldsSortData[ WG_SIZE*NUM_PER_WI + 16 ];
|
||||
|
||||
int nElemsPerWG = WG_SIZE*NUM_PER_WI;
|
||||
u32 lIdx = GET_LOCAL_IDX;
|
||||
u32 wgIdx = GET_GROUP_IDX;
|
||||
u32 wgSize = GET_GROUP_SIZE;
|
||||
|
||||
uint4 localAddr = make_uint4(lIdx*4+0,lIdx*4+1,lIdx*4+2,lIdx*4+3);
|
||||
|
||||
|
||||
SortData sortData[NUM_PER_WI];
|
||||
|
||||
{
|
||||
u32 offset = nElemsPerWG*wgIdx;
|
||||
sortData[0] = sortDataIn[offset+localAddr.x];
|
||||
sortData[1] = sortDataIn[offset+localAddr.y];
|
||||
sortData[2] = sortDataIn[offset+localAddr.z];
|
||||
sortData[3] = sortDataIn[offset+localAddr.w];
|
||||
}
|
||||
|
||||
int bitIdx = cb.m_startBit;
|
||||
do
|
||||
{
|
||||
// what is this?
|
||||
// if( lIdx == wgSize-1 ) ldsSortData[256] = sortData[3].m_key;
|
||||
u32 mask = (1<<bitIdx);
|
||||
uint4 cmpResult = make_uint4( sortData[0].m_key & mask, sortData[1].m_key & mask, sortData[2].m_key & mask, sortData[3].m_key & mask );
|
||||
uint4 prefixSum = SELECT_UINT4( make_uint4(1,1,1,1), make_uint4(0,0,0,0), cmpResult != make_uint4(0,0,0,0) );
|
||||
u32 total;
|
||||
prefixSum = localPrefixSum128V( prefixSum, lIdx, &total, ldsSortData );
|
||||
|
||||
{
|
||||
uint4 dstAddr = localAddr - prefixSum + make_uint4( total, total, total, total );
|
||||
dstAddr = SELECT_UINT4( prefixSum, dstAddr, cmpResult != make_uint4(0, 0, 0, 0) );
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
ldsSortData[dstAddr.x] = sortData[0].m_key;
|
||||
ldsSortData[dstAddr.y] = sortData[1].m_key;
|
||||
ldsSortData[dstAddr.z] = sortData[2].m_key;
|
||||
ldsSortData[dstAddr.w] = sortData[3].m_key;
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
sortData[0].m_key = ldsSortData[localAddr.x];
|
||||
sortData[1].m_key = ldsSortData[localAddr.y];
|
||||
sortData[2].m_key = ldsSortData[localAddr.z];
|
||||
sortData[3].m_key = ldsSortData[localAddr.w];
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
ldsSortData[dstAddr.x] = sortData[0].m_value;
|
||||
ldsSortData[dstAddr.y] = sortData[1].m_value;
|
||||
ldsSortData[dstAddr.z] = sortData[2].m_value;
|
||||
ldsSortData[dstAddr.w] = sortData[3].m_value;
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
sortData[0].m_value = ldsSortData[localAddr.x];
|
||||
sortData[1].m_value = ldsSortData[localAddr.y];
|
||||
sortData[2].m_value = ldsSortData[localAddr.z];
|
||||
sortData[3].m_value = ldsSortData[localAddr.w];
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
}
|
||||
bitIdx ++;
|
||||
}
|
||||
while( bitIdx <(cb.m_startBit+BITS_PER_PASS) );
|
||||
|
||||
{ // generate historgram
|
||||
uint4 localKeys = make_uint4( sortData[0].m_key>>cb.m_startBit, sortData[1].m_key>>cb.m_startBit,
|
||||
sortData[2].m_key>>cb.m_startBit, sortData[3].m_key>>cb.m_startBit );
|
||||
|
||||
generateHistogram( lIdx, wgIdx, localKeys, ldsSortData );
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
int nBins = (1<<BITS_PER_PASS);
|
||||
if( lIdx < nBins )
|
||||
{
|
||||
u32 histValues = ldsSortData[lIdx];
|
||||
|
||||
u32 globalAddresses = nBins*wgIdx + lIdx;
|
||||
u32 globalAddressesRadixMajor = cb.m_numGroups*lIdx + wgIdx;
|
||||
|
||||
ldsHistogramOut0[globalAddressesRadixMajor] = histValues;
|
||||
ldsHistogramOut1[globalAddresses] = histValues;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
{ // write
|
||||
u32 offset = nElemsPerWG*wgIdx;
|
||||
uint4 dstAddr = make_uint4(offset+localAddr.x, offset+localAddr.y, offset+localAddr.z, offset+localAddr.w );
|
||||
|
||||
sortDataIn[ dstAddr.x + 0 ] = sortData[0];
|
||||
sortDataIn[ dstAddr.x + 1 ] = sortData[1];
|
||||
sortDataIn[ dstAddr.x + 2 ] = sortData[2];
|
||||
sortDataIn[ dstAddr.x + 3 ] = sortData[3];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
__kernel
|
||||
__attribute__((reqd_work_group_size(WG_SIZE,1,1)))
|
||||
void ScatterKernel(__global SortData *src,
|
||||
__global u32 *histogramGlobalRadixMajor,
|
||||
__global u32 *histogramLocalGroupMajor,
|
||||
__global SortData *dst,
|
||||
ConstBuffer cb)
|
||||
{
|
||||
__local u32 sorterLocalMemory[3*(1<<BITS_PER_PASS)];
|
||||
__local u32 *ldsLocalHistogram = sorterLocalMemory + (1<<BITS_PER_PASS);
|
||||
__local u32 *ldsGlobalHistogram = sorterLocalMemory;
|
||||
|
||||
|
||||
u32 lIdx = GET_LOCAL_IDX;
|
||||
u32 wgIdx = GET_GROUP_IDX;
|
||||
u32 ldsOffset = (1<<BITS_PER_PASS);
|
||||
|
||||
// load and prefix scan local histogram
|
||||
if( lIdx < ((1<<BITS_PER_PASS)/2) )
|
||||
{
|
||||
uint2 myIdx = make_uint2(lIdx, lIdx+8);
|
||||
|
||||
ldsLocalHistogram[ldsOffset+myIdx.x] = histogramLocalGroupMajor[(1<<BITS_PER_PASS)*wgIdx + myIdx.x];
|
||||
ldsLocalHistogram[ldsOffset+myIdx.y] = histogramLocalGroupMajor[(1<<BITS_PER_PASS)*wgIdx + myIdx.y];
|
||||
ldsLocalHistogram[ldsOffset+myIdx.x-(1<<BITS_PER_PASS)] = 0;
|
||||
ldsLocalHistogram[ldsOffset+myIdx.y-(1<<BITS_PER_PASS)] = 0;
|
||||
|
||||
int idx = ldsOffset+2*lIdx;
|
||||
ldsLocalHistogram[idx] += ldsLocalHistogram[idx-1];
|
||||
GROUP_MEM_FENCE;
|
||||
ldsLocalHistogram[idx] += ldsLocalHistogram[idx-2];
|
||||
GROUP_MEM_FENCE;
|
||||
ldsLocalHistogram[idx] += ldsLocalHistogram[idx-4];
|
||||
GROUP_MEM_FENCE;
|
||||
ldsLocalHistogram[idx] += ldsLocalHistogram[idx-8];
|
||||
GROUP_MEM_FENCE;
|
||||
|
||||
// Propagate intermediate values through
|
||||
ldsLocalHistogram[idx-1] += ldsLocalHistogram[idx-2];
|
||||
GROUP_MEM_FENCE;
|
||||
|
||||
// Grab and propagate for whole WG - loading the - 1 value
|
||||
uint2 localValues;
|
||||
localValues.x = ldsLocalHistogram[ldsOffset+myIdx.x-1];
|
||||
localValues.y = ldsLocalHistogram[ldsOffset+myIdx.y-1];
|
||||
|
||||
ldsLocalHistogram[myIdx.x] = localValues.x;
|
||||
ldsLocalHistogram[myIdx.y] = localValues.y;
|
||||
|
||||
|
||||
ldsGlobalHistogram[myIdx.x] = histogramGlobalRadixMajor[cb.m_numGroups*myIdx.x + wgIdx];
|
||||
ldsGlobalHistogram[myIdx.y] = histogramGlobalRadixMajor[cb.m_numGroups*myIdx.y + wgIdx];
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
uint4 localAddr = make_uint4(lIdx*4+0,lIdx*4+1,lIdx*4+2,lIdx*4+3);
|
||||
|
||||
SortData sortData[4];
|
||||
{
|
||||
uint4 globalAddr = wgIdx*WG_SIZE*NUM_PER_WI + localAddr;
|
||||
sortData[0] = src[globalAddr.x];
|
||||
sortData[1] = src[globalAddr.y];
|
||||
sortData[2] = src[globalAddr.z];
|
||||
sortData[3] = src[globalAddr.w];
|
||||
}
|
||||
|
||||
uint cmpValue = ((1<<BITS_PER_PASS)-1);
|
||||
uint4 radix = make_uint4( (sortData[0].m_key>>cb.m_startBit)&cmpValue, (sortData[1].m_key>>cb.m_startBit)&cmpValue,
|
||||
(sortData[2].m_key>>cb.m_startBit)&cmpValue, (sortData[3].m_key>>cb.m_startBit)&cmpValue );;
|
||||
|
||||
// data is already sorted. So simply subtract local prefix sum
|
||||
uint4 dstAddr;
|
||||
dstAddr.x = ldsGlobalHistogram[radix.x] + (localAddr.x - ldsLocalHistogram[radix.x]);
|
||||
dstAddr.y = ldsGlobalHistogram[radix.y] + (localAddr.y - ldsLocalHistogram[radix.y]);
|
||||
dstAddr.z = ldsGlobalHistogram[radix.z] + (localAddr.z - ldsLocalHistogram[radix.z]);
|
||||
dstAddr.w = ldsGlobalHistogram[radix.w] + (localAddr.w - ldsLocalHistogram[radix.w]);
|
||||
|
||||
dst[dstAddr.x] = sortData[0];
|
||||
dst[dstAddr.y] = sortData[1];
|
||||
dst[dstAddr.z] = sortData[2];
|
||||
dst[dstAddr.w] = sortData[3];
|
||||
}
|
||||
|
||||
__kernel
|
||||
__attribute__((reqd_work_group_size(WG_SIZE,1,1)))
|
||||
void CopyKernel(__global SortData *src, __global SortData *dst)
|
||||
{
|
||||
dst[ GET_GLOBAL_IDX ] = src[ GET_GLOBAL_IDX ];
|
||||
}
|
||||
@@ -0,0 +1,322 @@
|
||||
/*
|
||||
2011 Takahiro Harada
|
||||
*/
|
||||
|
||||
typedef uint u32;
|
||||
|
||||
#define GET_GROUP_IDX groupIdx.x
|
||||
#define GET_LOCAL_IDX localIdx.x
|
||||
#define GET_GLOBAL_IDX globalIdx.x
|
||||
#define GROUP_LDS_BARRIER GroupMemoryBarrierWithGroupSync()
|
||||
#define GROUP_MEM_FENCE
|
||||
#define DEFAULT_ARGS uint3 globalIdx : SV_DispatchThreadID, uint3 localIdx : SV_GroupThreadID, uint3 groupIdx : SV_GroupID
|
||||
#define AtomInc(x) InterlockedAdd(x, 1)
|
||||
#define AtomInc1(x, out) InterlockedAdd(x, 1, out)
|
||||
|
||||
#define make_uint4 uint4
|
||||
#define make_uint2 uint2
|
||||
|
||||
uint4 SELECT_UINT4(uint4 b,uint4 a,uint4 condition ){ return make_uint4( ((condition).x)?a.x:b.x, ((condition).y)?a.y:b.y, ((condition).z)?a.z:b.z, ((condition).w)?a.w:b.w ); }
|
||||
|
||||
// takahiro end
|
||||
#define WG_SIZE 128
|
||||
#define NUM_PER_WI 4
|
||||
|
||||
#define GET_GROUP_SIZE WG_SIZE
|
||||
|
||||
typedef struct
|
||||
{
|
||||
u32 m_key;
|
||||
u32 m_value;
|
||||
}SortData;
|
||||
|
||||
cbuffer SortCB : register( b0 )
|
||||
{
|
||||
u32 m_startBit;
|
||||
u32 m_numGroups;
|
||||
u32 m_padding[2];
|
||||
};
|
||||
|
||||
#define BITS_PER_PASS 4
|
||||
|
||||
|
||||
uint4 prefixScanVector( uint4 data )
|
||||
{
|
||||
data.y += data.x;
|
||||
data.w += data.z;
|
||||
data.z += data.y;
|
||||
data.w += data.y;
|
||||
return data;
|
||||
}
|
||||
|
||||
uint prefixScanVectorEx( inout uint4 data )
|
||||
{
|
||||
uint4 backup = data;
|
||||
data.y += data.x;
|
||||
data.w += data.z;
|
||||
data.z += data.y;
|
||||
data.w += data.y;
|
||||
uint sum = data.w;
|
||||
data -= backup;
|
||||
return sum;
|
||||
}
|
||||
|
||||
|
||||
|
||||
RWStructuredBuffer<SortData> sortDataIn : register( u0 );
|
||||
RWStructuredBuffer<u32> ldsHistogramOut0 : register( u1 );
|
||||
RWStructuredBuffer<u32> ldsHistogramOut1 : register( u2 );
|
||||
|
||||
groupshared u32 ldsSortData[ WG_SIZE*NUM_PER_WI + 16 ];
|
||||
|
||||
|
||||
uint4 localPrefixSum128V( uint4 pData, uint lIdx, inout uint totalSum )
|
||||
{
|
||||
{ // Set data
|
||||
ldsSortData[lIdx] = 0;
|
||||
ldsSortData[lIdx+WG_SIZE] = prefixScanVectorEx( pData );
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
{ // Prefix sum
|
||||
int idx = 2*lIdx + (WG_SIZE+1);
|
||||
if( lIdx < 64 )
|
||||
{
|
||||
ldsSortData[idx] += ldsSortData[idx-1];
|
||||
GROUP_MEM_FENCE;
|
||||
ldsSortData[idx] += ldsSortData[idx-2];
|
||||
GROUP_MEM_FENCE;
|
||||
ldsSortData[idx] += ldsSortData[idx-4];
|
||||
GROUP_MEM_FENCE;
|
||||
ldsSortData[idx] += ldsSortData[idx-8];
|
||||
GROUP_MEM_FENCE;
|
||||
ldsSortData[idx] += ldsSortData[idx-16];
|
||||
GROUP_MEM_FENCE;
|
||||
ldsSortData[idx] += ldsSortData[idx-32];
|
||||
GROUP_MEM_FENCE;
|
||||
ldsSortData[idx] += ldsSortData[idx-64];
|
||||
GROUP_MEM_FENCE;
|
||||
|
||||
ldsSortData[idx-1] += ldsSortData[idx-2];
|
||||
GROUP_MEM_FENCE;
|
||||
}
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
totalSum = ldsSortData[WG_SIZE*2-1];
|
||||
uint addValue = ldsSortData[lIdx+127];
|
||||
return pData + make_uint4(addValue, addValue, addValue, addValue);
|
||||
}
|
||||
|
||||
void generateHistogram(u32 lIdx, u32 wgIdx,
|
||||
uint4 sortedData)
|
||||
{
|
||||
if( lIdx < (1<<BITS_PER_PASS) )
|
||||
{
|
||||
ldsSortData[lIdx] = 0;
|
||||
}
|
||||
|
||||
int mask = ((1<<BITS_PER_PASS)-1);
|
||||
uint4 keys = make_uint4( (sortedData.x)&mask, (sortedData.y)&mask, (sortedData.z)&mask, (sortedData.w)&mask );
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
AtomInc( ldsSortData[keys.x] );
|
||||
AtomInc( ldsSortData[keys.y] );
|
||||
AtomInc( ldsSortData[keys.z] );
|
||||
AtomInc( ldsSortData[keys.w] );
|
||||
}
|
||||
|
||||
[numthreads(WG_SIZE, 1, 1)]
|
||||
void LocalSortKernel( DEFAULT_ARGS )
|
||||
{
|
||||
int nElemsPerWG = WG_SIZE*NUM_PER_WI;
|
||||
u32 lIdx = GET_LOCAL_IDX;
|
||||
u32 wgIdx = GET_GROUP_IDX;
|
||||
u32 wgSize = GET_GROUP_SIZE;
|
||||
|
||||
uint4 localAddr = make_uint4(lIdx*4+0,lIdx*4+1,lIdx*4+2,lIdx*4+3);
|
||||
|
||||
|
||||
SortData sortData[NUM_PER_WI];
|
||||
|
||||
{
|
||||
u32 offset = nElemsPerWG*wgIdx;
|
||||
sortData[0] = sortDataIn[offset+localAddr.x];
|
||||
sortData[1] = sortDataIn[offset+localAddr.y];
|
||||
sortData[2] = sortDataIn[offset+localAddr.z];
|
||||
sortData[3] = sortDataIn[offset+localAddr.w];
|
||||
}
|
||||
|
||||
int bitIdx = m_startBit;
|
||||
do
|
||||
{
|
||||
// what is this?
|
||||
// if( lIdx == wgSize-1 ) ldsSortData[256] = sortData[3].m_key;
|
||||
u32 mask = (1<<bitIdx);
|
||||
uint4 cmpResult = make_uint4( sortData[0].m_key & mask, sortData[1].m_key & mask, sortData[2].m_key & mask, sortData[3].m_key & mask );
|
||||
uint4 prefixSum = SELECT_UINT4( make_uint4(1,1,1,1), make_uint4(0,0,0,0), cmpResult != make_uint4(0,0,0,0) );
|
||||
u32 total;
|
||||
prefixSum = localPrefixSum128V( prefixSum, lIdx, total );
|
||||
|
||||
{
|
||||
uint4 dstAddr = localAddr - prefixSum + make_uint4( total, total, total, total );
|
||||
dstAddr = SELECT_UINT4( prefixSum, dstAddr, cmpResult != make_uint4(0, 0, 0, 0) );
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
ldsSortData[dstAddr.x] = sortData[0].m_key;
|
||||
ldsSortData[dstAddr.y] = sortData[1].m_key;
|
||||
ldsSortData[dstAddr.z] = sortData[2].m_key;
|
||||
ldsSortData[dstAddr.w] = sortData[3].m_key;
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
sortData[0].m_key = ldsSortData[localAddr.x];
|
||||
sortData[1].m_key = ldsSortData[localAddr.y];
|
||||
sortData[2].m_key = ldsSortData[localAddr.z];
|
||||
sortData[3].m_key = ldsSortData[localAddr.w];
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
ldsSortData[dstAddr.x] = sortData[0].m_value;
|
||||
ldsSortData[dstAddr.y] = sortData[1].m_value;
|
||||
ldsSortData[dstAddr.z] = sortData[2].m_value;
|
||||
ldsSortData[dstAddr.w] = sortData[3].m_value;
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
sortData[0].m_value = ldsSortData[localAddr.x];
|
||||
sortData[1].m_value = ldsSortData[localAddr.y];
|
||||
sortData[2].m_value = ldsSortData[localAddr.z];
|
||||
sortData[3].m_value = ldsSortData[localAddr.w];
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
}
|
||||
bitIdx ++;
|
||||
}
|
||||
while( bitIdx <(m_startBit+BITS_PER_PASS) );
|
||||
|
||||
{ // generate historgram
|
||||
uint4 localKeys = make_uint4( sortData[0].m_key>>m_startBit, sortData[1].m_key>>m_startBit,
|
||||
sortData[2].m_key>>m_startBit, sortData[3].m_key>>m_startBit );
|
||||
|
||||
generateHistogram( lIdx, wgIdx, localKeys );
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
int nBins = (1<<BITS_PER_PASS);
|
||||
if( lIdx < nBins )
|
||||
{
|
||||
u32 histValues = ldsSortData[lIdx];
|
||||
|
||||
u32 globalAddresses = nBins*wgIdx + lIdx;
|
||||
u32 globalAddressesRadixMajor = m_numGroups*lIdx + wgIdx;
|
||||
|
||||
ldsHistogramOut0[globalAddressesRadixMajor] = histValues;
|
||||
ldsHistogramOut1[globalAddresses] = histValues;
|
||||
}
|
||||
}
|
||||
|
||||
{ // write
|
||||
u32 offset = nElemsPerWG*wgIdx;
|
||||
uint4 dstAddr = make_uint4(offset+localAddr.x, offset+localAddr.y, offset+localAddr.z, offset+localAddr.w );
|
||||
|
||||
sortDataIn[ dstAddr.x + 0 ] = sortData[0];
|
||||
sortDataIn[ dstAddr.x + 1 ] = sortData[1];
|
||||
sortDataIn[ dstAddr.x + 2 ] = sortData[2];
|
||||
sortDataIn[ dstAddr.x + 3 ] = sortData[3];
|
||||
}
|
||||
}
|
||||
|
||||
StructuredBuffer<SortData> src : register( t0 );
|
||||
StructuredBuffer<u32> histogramGlobalRadixMajor : register( t1 );
|
||||
StructuredBuffer<u32> histogramLocalGroupMajor : register( t2 );
|
||||
|
||||
RWStructuredBuffer<SortData> dst : register( u0 );
|
||||
|
||||
groupshared u32 ldsLocalHistogram[ 2*(1<<BITS_PER_PASS) ];
|
||||
groupshared u32 ldsGlobalHistogram[ (1<<BITS_PER_PASS) ];
|
||||
|
||||
|
||||
[numthreads(WG_SIZE, 1, 1)]
|
||||
void ScatterKernel( DEFAULT_ARGS )
|
||||
{
|
||||
u32 lIdx = GET_LOCAL_IDX;
|
||||
u32 wgIdx = GET_GROUP_IDX;
|
||||
u32 ldsOffset = (1<<BITS_PER_PASS);
|
||||
|
||||
// load and prefix scan local histogram
|
||||
if( lIdx < ((1<<BITS_PER_PASS)/2) )
|
||||
{
|
||||
uint2 myIdx = make_uint2(lIdx, lIdx+8);
|
||||
|
||||
ldsLocalHistogram[ldsOffset+myIdx.x] = histogramLocalGroupMajor[(1<<BITS_PER_PASS)*wgIdx + myIdx.x];
|
||||
ldsLocalHistogram[ldsOffset+myIdx.y] = histogramLocalGroupMajor[(1<<BITS_PER_PASS)*wgIdx + myIdx.y];
|
||||
ldsLocalHistogram[ldsOffset+myIdx.x-(1<<BITS_PER_PASS)] = 0;
|
||||
ldsLocalHistogram[ldsOffset+myIdx.y-(1<<BITS_PER_PASS)] = 0;
|
||||
|
||||
int idx = ldsOffset+2*lIdx;
|
||||
ldsLocalHistogram[idx] += ldsLocalHistogram[idx-1];
|
||||
GROUP_MEM_FENCE;
|
||||
ldsLocalHistogram[idx] += ldsLocalHistogram[idx-2];
|
||||
GROUP_MEM_FENCE;
|
||||
ldsLocalHistogram[idx] += ldsLocalHistogram[idx-4];
|
||||
GROUP_MEM_FENCE;
|
||||
ldsLocalHistogram[idx] += ldsLocalHistogram[idx-8];
|
||||
GROUP_MEM_FENCE;
|
||||
|
||||
// Propagate intermediate values through
|
||||
ldsLocalHistogram[idx-1] += ldsLocalHistogram[idx-2];
|
||||
GROUP_MEM_FENCE;
|
||||
|
||||
// Grab and propagate for whole WG - loading the - 1 value
|
||||
uint2 localValues;
|
||||
localValues.x = ldsLocalHistogram[ldsOffset+myIdx.x-1];
|
||||
localValues.y = ldsLocalHistogram[ldsOffset+myIdx.y-1];
|
||||
|
||||
ldsLocalHistogram[myIdx.x] = localValues.x;
|
||||
ldsLocalHistogram[myIdx.y] = localValues.y;
|
||||
|
||||
|
||||
ldsGlobalHistogram[myIdx.x] = histogramGlobalRadixMajor[m_numGroups*myIdx.x + wgIdx];
|
||||
ldsGlobalHistogram[myIdx.y] = histogramGlobalRadixMajor[m_numGroups*myIdx.y + wgIdx];
|
||||
}
|
||||
|
||||
GROUP_LDS_BARRIER;
|
||||
|
||||
uint4 localAddr = make_uint4(lIdx*4+0,lIdx*4+1,lIdx*4+2,lIdx*4+3);
|
||||
|
||||
SortData sortData[4];
|
||||
{
|
||||
uint4 globalAddr = wgIdx*WG_SIZE*NUM_PER_WI + localAddr;
|
||||
sortData[0] = src[globalAddr.x];
|
||||
sortData[1] = src[globalAddr.y];
|
||||
sortData[2] = src[globalAddr.z];
|
||||
sortData[3] = src[globalAddr.w];
|
||||
}
|
||||
|
||||
uint cmpValue = ((1<<BITS_PER_PASS)-1);
|
||||
uint4 radix = make_uint4( (sortData[0].m_key>>m_startBit)&cmpValue, (sortData[1].m_key>>m_startBit)&cmpValue,
|
||||
(sortData[2].m_key>>m_startBit)&cmpValue, (sortData[3].m_key>>m_startBit)&cmpValue );;
|
||||
|
||||
// data is already sorted. So simply subtract local prefix sum
|
||||
uint4 dstAddr;
|
||||
dstAddr.x = ldsGlobalHistogram[radix.x] + (localAddr.x - ldsLocalHistogram[radix.x]);
|
||||
dstAddr.y = ldsGlobalHistogram[radix.y] + (localAddr.y - ldsLocalHistogram[radix.y]);
|
||||
dstAddr.z = ldsGlobalHistogram[radix.z] + (localAddr.z - ldsLocalHistogram[radix.z]);
|
||||
dstAddr.w = ldsGlobalHistogram[radix.w] + (localAddr.w - ldsLocalHistogram[radix.w]);
|
||||
|
||||
dst[dstAddr.x] = sortData[0];
|
||||
dst[dstAddr.y] = sortData[1];
|
||||
dst[dstAddr.z] = sortData[2];
|
||||
dst[dstAddr.w] = sortData[3];
|
||||
}
|
||||
|
||||
[numthreads(WG_SIZE, 1, 1)]
|
||||
void CopyKernel( DEFAULT_ARGS )
|
||||
{
|
||||
dst[ GET_GLOBAL_IDX ] = src[ GET_GLOBAL_IDX ];
|
||||
}
|
||||
@@ -0,0 +1,347 @@
|
||||
static const char* radixSortStandardKernelsCL= \
|
||||
"/*\n"
|
||||
"Bullet Continuous Collision Detection and Physics Library\n"
|
||||
"Copyright (c) 2011 Advanced Micro Devices, Inc. http://bulletphysics.org\n"
|
||||
"\n"
|
||||
"This software is provided 'as-is', without any express or implied warranty.\n"
|
||||
"In no event will the authors be held liable for any damages arising from the use of this software.\n"
|
||||
"Permission is granted to anyone to use this software for any purpose, \n"
|
||||
"including commercial applications, and to alter it and redistribute it freely, \n"
|
||||
"subject to the following restrictions:\n"
|
||||
"\n"
|
||||
"1. The origin of this software must not be misrepresented; you must not claim that you wrote the original software. If you use this software in a product, an acknowledgment in the product documentation would be appreciated but is not required.\n"
|
||||
"2. Altered source versions must be plainly marked as such, and must not be misrepresented as being the original software.\n"
|
||||
"3. This notice may not be removed or altered from any source distribution.\n"
|
||||
"*/\n"
|
||||
"//Author Takahiro Harada\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"#pragma OPENCL EXTENSION cl_amd_printf : enable\n"
|
||||
"#pragma OPENCL EXTENSION cl_khr_local_int32_base_atomics : enable\n"
|
||||
"\n"
|
||||
"typedef unsigned int u32;\n"
|
||||
"#define GET_GROUP_IDX get_group_id(0)\n"
|
||||
"#define GET_LOCAL_IDX get_local_id(0)\n"
|
||||
"#define GET_GLOBAL_IDX get_global_id(0)\n"
|
||||
"#define GET_GROUP_SIZE get_local_size(0)\n"
|
||||
"#define GROUP_LDS_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n"
|
||||
"#define GROUP_MEM_FENCE mem_fence(CLK_LOCAL_MEM_FENCE)\n"
|
||||
"#define AtomInc(x) atom_inc(&(x))\n"
|
||||
"#define AtomInc1(x, out) out = atom_inc(&(x))\n"
|
||||
"\n"
|
||||
"#define make_uint4 (uint4)\n"
|
||||
"#define make_uint2 (uint2)\n"
|
||||
"\n"
|
||||
"#define SELECT_UINT4( b, a, condition ) select( b,a,condition )\n"
|
||||
"\n"
|
||||
"#define WG_SIZE 128\n"
|
||||
"#define NUM_PER_WI 4\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"typedef struct\n"
|
||||
"{\n"
|
||||
" u32 m_key; \n"
|
||||
" u32 m_value;\n"
|
||||
"}SortData;\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"typedef struct\n"
|
||||
"{\n"
|
||||
" u32 m_startBit;\n"
|
||||
" u32 m_numGroups;\n"
|
||||
" u32 m_padding[2];\n"
|
||||
"} ConstBuffer;\n"
|
||||
"\n"
|
||||
"#define BITS_PER_PASS 4\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"uint4 prefixScanVector( uint4 data )\n"
|
||||
"{\n"
|
||||
" data.y += data.x;\n"
|
||||
" data.w += data.z;\n"
|
||||
" data.z += data.y;\n"
|
||||
" data.w += data.y;\n"
|
||||
" return data;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"uint prefixScanVectorEx( uint4* data )\n"
|
||||
"{\n"
|
||||
" uint4 backup = data[0];\n"
|
||||
" data[0].y += data[0].x;\n"
|
||||
" data[0].w += data[0].z;\n"
|
||||
" data[0].z += data[0].y;\n"
|
||||
" data[0].w += data[0].y;\n"
|
||||
" uint sum = data[0].w;\n"
|
||||
" *data -= backup;\n"
|
||||
" return sum;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"uint4 localPrefixSum128V( uint4 pData, uint lIdx, uint* totalSum, __local u32 sorterSharedMemory[] )\n"
|
||||
"{\n"
|
||||
" { // Set data\n"
|
||||
" sorterSharedMemory[lIdx] = 0;\n"
|
||||
" sorterSharedMemory[lIdx+WG_SIZE] = prefixScanVectorEx( &pData );\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" { // Prefix sum\n"
|
||||
" int idx = 2*lIdx + (WG_SIZE+1);\n"
|
||||
" if( lIdx < 64 )\n"
|
||||
" {\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-1];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-2]; \n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-4];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-8];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-16];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-32]; \n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" sorterSharedMemory[idx] += sorterSharedMemory[idx-64];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
"\n"
|
||||
" sorterSharedMemory[idx-1] += sorterSharedMemory[idx-2];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" *totalSum = sorterSharedMemory[WG_SIZE*2-1];\n"
|
||||
" uint addValue = sorterSharedMemory[lIdx+127];\n"
|
||||
" return pData + make_uint4(addValue, addValue, addValue, addValue);\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"void generateHistogram(u32 lIdx, u32 wgIdx, \n"
|
||||
" uint4 sortedData,\n"
|
||||
" __local u32 *histogram)\n"
|
||||
"{\n"
|
||||
" if( lIdx < (1<<BITS_PER_PASS) )\n"
|
||||
" {\n"
|
||||
" histogram[lIdx] = 0;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" int mask = ((1<<BITS_PER_PASS)-1);\n"
|
||||
" uint4 keys = make_uint4( (sortedData.x)&mask, (sortedData.y)&mask, (sortedData.z)&mask, (sortedData.w)&mask );\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
" AtomInc( histogram[keys.x] );\n"
|
||||
" AtomInc( histogram[keys.y] );\n"
|
||||
" AtomInc( histogram[keys.z] );\n"
|
||||
" AtomInc( histogram[keys.w] );\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"//\n"
|
||||
"//\n"
|
||||
"//\n"
|
||||
"\n"
|
||||
"__kernel\n"
|
||||
"__attribute__((reqd_work_group_size(WG_SIZE,1,1)))\n"
|
||||
"void LocalSortKernel(__global SortData* sortDataIn, \n"
|
||||
" __global u32* ldsHistogramOut0,\n"
|
||||
" __global u32* ldsHistogramOut1,\n"
|
||||
" ConstBuffer cb)\n"
|
||||
"{\n"
|
||||
"\n"
|
||||
" __local u32 ldsSortData[ WG_SIZE*NUM_PER_WI + 16 ];\n"
|
||||
"\n"
|
||||
" int nElemsPerWG = WG_SIZE*NUM_PER_WI;\n"
|
||||
" u32 lIdx = GET_LOCAL_IDX;\n"
|
||||
" u32 wgIdx = GET_GROUP_IDX;\n"
|
||||
" u32 wgSize = GET_GROUP_SIZE;\n"
|
||||
"\n"
|
||||
" uint4 localAddr = make_uint4(lIdx*4+0,lIdx*4+1,lIdx*4+2,lIdx*4+3);\n"
|
||||
"\n"
|
||||
"\n"
|
||||
" SortData sortData[NUM_PER_WI];\n"
|
||||
"\n"
|
||||
" {\n"
|
||||
" u32 offset = nElemsPerWG*wgIdx;\n"
|
||||
" sortData[0] = sortDataIn[offset+localAddr.x];\n"
|
||||
" sortData[1] = sortDataIn[offset+localAddr.y];\n"
|
||||
" sortData[2] = sortDataIn[offset+localAddr.z];\n"
|
||||
" sortData[3] = sortDataIn[offset+localAddr.w];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" int bitIdx = cb.m_startBit;\n"
|
||||
" do\n"
|
||||
" {\n"
|
||||
"// what is this?\n"
|
||||
"// if( lIdx == wgSize-1 ) ldsSortData[256] = sortData[3].m_key;\n"
|
||||
" u32 mask = (1<<bitIdx);\n"
|
||||
" uint4 cmpResult = make_uint4( sortData[0].m_key & mask, sortData[1].m_key & mask, sortData[2].m_key & mask, sortData[3].m_key & mask );\n"
|
||||
" uint4 prefixSum = SELECT_UINT4( make_uint4(1,1,1,1), make_uint4(0,0,0,0), cmpResult != make_uint4(0,0,0,0) );\n"
|
||||
" u32 total;\n"
|
||||
" prefixSum = localPrefixSum128V( prefixSum, lIdx, &total, ldsSortData );\n"
|
||||
"\n"
|
||||
" {\n"
|
||||
" uint4 dstAddr = localAddr - prefixSum + make_uint4( total, total, total, total );\n"
|
||||
" dstAddr = SELECT_UINT4( prefixSum, dstAddr, cmpResult != make_uint4(0, 0, 0, 0) );\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" ldsSortData[dstAddr.x] = sortData[0].m_key;\n"
|
||||
" ldsSortData[dstAddr.y] = sortData[1].m_key;\n"
|
||||
" ldsSortData[dstAddr.z] = sortData[2].m_key;\n"
|
||||
" ldsSortData[dstAddr.w] = sortData[3].m_key;\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" sortData[0].m_key = ldsSortData[localAddr.x];\n"
|
||||
" sortData[1].m_key = ldsSortData[localAddr.y];\n"
|
||||
" sortData[2].m_key = ldsSortData[localAddr.z];\n"
|
||||
" sortData[3].m_key = ldsSortData[localAddr.w];\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" ldsSortData[dstAddr.x] = sortData[0].m_value;\n"
|
||||
" ldsSortData[dstAddr.y] = sortData[1].m_value;\n"
|
||||
" ldsSortData[dstAddr.z] = sortData[2].m_value;\n"
|
||||
" ldsSortData[dstAddr.w] = sortData[3].m_value;\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" sortData[0].m_value = ldsSortData[localAddr.x];\n"
|
||||
" sortData[1].m_value = ldsSortData[localAddr.y];\n"
|
||||
" sortData[2].m_value = ldsSortData[localAddr.z];\n"
|
||||
" sortData[3].m_value = ldsSortData[localAddr.w];\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" }\n"
|
||||
" bitIdx ++;\n"
|
||||
" }\n"
|
||||
" while( bitIdx <(cb.m_startBit+BITS_PER_PASS) );\n"
|
||||
"\n"
|
||||
" { // generate historgram\n"
|
||||
" uint4 localKeys = make_uint4( sortData[0].m_key>>cb.m_startBit, sortData[1].m_key>>cb.m_startBit, \n"
|
||||
" sortData[2].m_key>>cb.m_startBit, sortData[3].m_key>>cb.m_startBit );\n"
|
||||
"\n"
|
||||
" generateHistogram( lIdx, wgIdx, localKeys, ldsSortData );\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" int nBins = (1<<BITS_PER_PASS);\n"
|
||||
" if( lIdx < nBins )\n"
|
||||
" {\n"
|
||||
" u32 histValues = ldsSortData[lIdx];\n"
|
||||
"\n"
|
||||
" u32 globalAddresses = nBins*wgIdx + lIdx;\n"
|
||||
" u32 globalAddressesRadixMajor = cb.m_numGroups*lIdx + wgIdx;\n"
|
||||
" \n"
|
||||
" ldsHistogramOut0[globalAddressesRadixMajor] = histValues;\n"
|
||||
" ldsHistogramOut1[globalAddresses] = histValues;\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
"\n"
|
||||
" { // write\n"
|
||||
" u32 offset = nElemsPerWG*wgIdx;\n"
|
||||
" uint4 dstAddr = make_uint4(offset+localAddr.x, offset+localAddr.y, offset+localAddr.z, offset+localAddr.w );\n"
|
||||
"\n"
|
||||
" sortDataIn[ dstAddr.x + 0 ] = sortData[0];\n"
|
||||
" sortDataIn[ dstAddr.x + 1 ] = sortData[1];\n"
|
||||
" sortDataIn[ dstAddr.x + 2 ] = sortData[2];\n"
|
||||
" sortDataIn[ dstAddr.x + 3 ] = sortData[3];\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"__kernel\n"
|
||||
"__attribute__((reqd_work_group_size(WG_SIZE,1,1)))\n"
|
||||
"void ScatterKernel(__global SortData *src,\n"
|
||||
" __global u32 *histogramGlobalRadixMajor,\n"
|
||||
" __global u32 *histogramLocalGroupMajor,\n"
|
||||
" __global SortData *dst,\n"
|
||||
" ConstBuffer cb)\n"
|
||||
"{\n"
|
||||
" __local u32 sorterLocalMemory[3*(1<<BITS_PER_PASS)];\n"
|
||||
" __local u32 *ldsLocalHistogram = sorterLocalMemory + (1<<BITS_PER_PASS);\n"
|
||||
" __local u32 *ldsGlobalHistogram = sorterLocalMemory;\n"
|
||||
"\n"
|
||||
"\n"
|
||||
" u32 lIdx = GET_LOCAL_IDX;\n"
|
||||
" u32 wgIdx = GET_GROUP_IDX;\n"
|
||||
" u32 ldsOffset = (1<<BITS_PER_PASS);\n"
|
||||
"\n"
|
||||
" // load and prefix scan local histogram\n"
|
||||
" if( lIdx < ((1<<BITS_PER_PASS)/2) )\n"
|
||||
" {\n"
|
||||
" uint2 myIdx = make_uint2(lIdx, lIdx+8);\n"
|
||||
"\n"
|
||||
" ldsLocalHistogram[ldsOffset+myIdx.x] = histogramLocalGroupMajor[(1<<BITS_PER_PASS)*wgIdx + myIdx.x];\n"
|
||||
" ldsLocalHistogram[ldsOffset+myIdx.y] = histogramLocalGroupMajor[(1<<BITS_PER_PASS)*wgIdx + myIdx.y];\n"
|
||||
" ldsLocalHistogram[ldsOffset+myIdx.x-(1<<BITS_PER_PASS)] = 0;\n"
|
||||
" ldsLocalHistogram[ldsOffset+myIdx.y-(1<<BITS_PER_PASS)] = 0;\n"
|
||||
"\n"
|
||||
" int idx = ldsOffset+2*lIdx;\n"
|
||||
" ldsLocalHistogram[idx] += ldsLocalHistogram[idx-1];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" ldsLocalHistogram[idx] += ldsLocalHistogram[idx-2];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" ldsLocalHistogram[idx] += ldsLocalHistogram[idx-4];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" ldsLocalHistogram[idx] += ldsLocalHistogram[idx-8];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
"\n"
|
||||
" // Propagate intermediate values through\n"
|
||||
" ldsLocalHistogram[idx-1] += ldsLocalHistogram[idx-2];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
"\n"
|
||||
" // Grab and propagate for whole WG - loading the - 1 value\n"
|
||||
" uint2 localValues;\n"
|
||||
" localValues.x = ldsLocalHistogram[ldsOffset+myIdx.x-1];\n"
|
||||
" localValues.y = ldsLocalHistogram[ldsOffset+myIdx.y-1];\n"
|
||||
"\n"
|
||||
" ldsLocalHistogram[myIdx.x] = localValues.x;\n"
|
||||
" ldsLocalHistogram[myIdx.y] = localValues.y;\n"
|
||||
"\n"
|
||||
"\n"
|
||||
" ldsGlobalHistogram[myIdx.x] = histogramGlobalRadixMajor[cb.m_numGroups*myIdx.x + wgIdx];\n"
|
||||
" ldsGlobalHistogram[myIdx.y] = histogramGlobalRadixMajor[cb.m_numGroups*myIdx.y + wgIdx];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" uint4 localAddr = make_uint4(lIdx*4+0,lIdx*4+1,lIdx*4+2,lIdx*4+3);\n"
|
||||
"\n"
|
||||
" SortData sortData[4];\n"
|
||||
" {\n"
|
||||
" uint4 globalAddr = wgIdx*WG_SIZE*NUM_PER_WI + localAddr;\n"
|
||||
" sortData[0] = src[globalAddr.x];\n"
|
||||
" sortData[1] = src[globalAddr.y];\n"
|
||||
" sortData[2] = src[globalAddr.z];\n"
|
||||
" sortData[3] = src[globalAddr.w];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" uint cmpValue = ((1<<BITS_PER_PASS)-1);\n"
|
||||
" uint4 radix = make_uint4( (sortData[0].m_key>>cb.m_startBit)&cmpValue, (sortData[1].m_key>>cb.m_startBit)&cmpValue, \n"
|
||||
" (sortData[2].m_key>>cb.m_startBit)&cmpValue, (sortData[3].m_key>>cb.m_startBit)&cmpValue );;\n"
|
||||
"\n"
|
||||
" // data is already sorted. So simply subtract local prefix sum\n"
|
||||
" uint4 dstAddr;\n"
|
||||
" dstAddr.x = ldsGlobalHistogram[radix.x] + (localAddr.x - ldsLocalHistogram[radix.x]);\n"
|
||||
" dstAddr.y = ldsGlobalHistogram[radix.y] + (localAddr.y - ldsLocalHistogram[radix.y]);\n"
|
||||
" dstAddr.z = ldsGlobalHistogram[radix.z] + (localAddr.z - ldsLocalHistogram[radix.z]);\n"
|
||||
" dstAddr.w = ldsGlobalHistogram[radix.w] + (localAddr.w - ldsLocalHistogram[radix.w]);\n"
|
||||
"\n"
|
||||
" dst[dstAddr.x] = sortData[0];\n"
|
||||
" dst[dstAddr.y] = sortData[1];\n"
|
||||
" dst[dstAddr.z] = sortData[2];\n"
|
||||
" dst[dstAddr.w] = sortData[3];\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"__kernel\n"
|
||||
"__attribute__((reqd_work_group_size(WG_SIZE,1,1)))\n"
|
||||
"void CopyKernel(__global SortData *src, __global SortData *dst)\n"
|
||||
"{\n"
|
||||
" dst[ GET_GLOBAL_IDX ] = src[ GET_GLOBAL_IDX ];\n"
|
||||
"}\n"
|
||||
;
|
||||
@@ -0,0 +1,324 @@
|
||||
static const char* radixSortStandardKernelsDX11= \
|
||||
"/*\n"
|
||||
" 2011 Takahiro Harada\n"
|
||||
"*/\n"
|
||||
"\n"
|
||||
"typedef uint u32;\n"
|
||||
"\n"
|
||||
"#define GET_GROUP_IDX groupIdx.x\n"
|
||||
"#define GET_LOCAL_IDX localIdx.x\n"
|
||||
"#define GET_GLOBAL_IDX globalIdx.x\n"
|
||||
"#define GROUP_LDS_BARRIER GroupMemoryBarrierWithGroupSync()\n"
|
||||
"#define GROUP_MEM_FENCE\n"
|
||||
"#define DEFAULT_ARGS uint3 globalIdx : SV_DispatchThreadID, uint3 localIdx : SV_GroupThreadID, uint3 groupIdx : SV_GroupID\n"
|
||||
"#define AtomInc(x) InterlockedAdd(x, 1)\n"
|
||||
"#define AtomInc1(x, out) InterlockedAdd(x, 1, out)\n"
|
||||
"\n"
|
||||
"#define make_uint4 uint4\n"
|
||||
"#define make_uint2 uint2\n"
|
||||
"\n"
|
||||
"uint4 SELECT_UINT4(uint4 b,uint4 a,uint4 condition ){ return make_uint4( ((condition).x)?a.x:b.x, ((condition).y)?a.y:b.y, ((condition).z)?a.z:b.z, ((condition).w)?a.w:b.w ); }\n"
|
||||
"\n"
|
||||
"// takahiro end\n"
|
||||
"#define WG_SIZE 128\n"
|
||||
"#define NUM_PER_WI 4\n"
|
||||
"\n"
|
||||
"#define GET_GROUP_SIZE WG_SIZE\n"
|
||||
"\n"
|
||||
"typedef struct\n"
|
||||
"{\n"
|
||||
" u32 m_key; \n"
|
||||
" u32 m_value;\n"
|
||||
"}SortData;\n"
|
||||
"\n"
|
||||
"cbuffer SortCB : register( b0 )\n"
|
||||
"{\n"
|
||||
" u32 m_startBit;\n"
|
||||
" u32 m_numGroups;\n"
|
||||
" u32 m_padding[2];\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"#define BITS_PER_PASS 4\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"uint4 prefixScanVector( uint4 data )\n"
|
||||
"{\n"
|
||||
" data.y += data.x;\n"
|
||||
" data.w += data.z;\n"
|
||||
" data.z += data.y;\n"
|
||||
" data.w += data.y;\n"
|
||||
" return data;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"uint prefixScanVectorEx( inout uint4 data )\n"
|
||||
"{\n"
|
||||
" uint4 backup = data;\n"
|
||||
" data.y += data.x;\n"
|
||||
" data.w += data.z;\n"
|
||||
" data.z += data.y;\n"
|
||||
" data.w += data.y;\n"
|
||||
" uint sum = data.w;\n"
|
||||
" data -= backup;\n"
|
||||
" return sum;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"RWStructuredBuffer<SortData> sortDataIn : register( u0 );\n"
|
||||
"RWStructuredBuffer<u32> ldsHistogramOut0 : register( u1 );\n"
|
||||
"RWStructuredBuffer<u32> ldsHistogramOut1 : register( u2 );\n"
|
||||
"\n"
|
||||
"groupshared u32 ldsSortData[ WG_SIZE*NUM_PER_WI + 16 ];\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"uint4 localPrefixSum128V( uint4 pData, uint lIdx, inout uint totalSum )\n"
|
||||
"{\n"
|
||||
" { // Set data\n"
|
||||
" ldsSortData[lIdx] = 0;\n"
|
||||
" ldsSortData[lIdx+WG_SIZE] = prefixScanVectorEx( pData );\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" { // Prefix sum\n"
|
||||
" int idx = 2*lIdx + (WG_SIZE+1);\n"
|
||||
" if( lIdx < 64 )\n"
|
||||
" {\n"
|
||||
" ldsSortData[idx] += ldsSortData[idx-1];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" ldsSortData[idx] += ldsSortData[idx-2]; \n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" ldsSortData[idx] += ldsSortData[idx-4];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" ldsSortData[idx] += ldsSortData[idx-8];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" ldsSortData[idx] += ldsSortData[idx-16];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" ldsSortData[idx] += ldsSortData[idx-32]; \n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" ldsSortData[idx] += ldsSortData[idx-64];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
"\n"
|
||||
" ldsSortData[idx-1] += ldsSortData[idx-2];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" totalSum = ldsSortData[WG_SIZE*2-1];\n"
|
||||
" uint addValue = ldsSortData[lIdx+127];\n"
|
||||
" return pData + make_uint4(addValue, addValue, addValue, addValue);\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"void generateHistogram(u32 lIdx, u32 wgIdx, \n"
|
||||
" uint4 sortedData)\n"
|
||||
"{\n"
|
||||
" if( lIdx < (1<<BITS_PER_PASS) )\n"
|
||||
" {\n"
|
||||
" ldsSortData[lIdx] = 0;\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" int mask = ((1<<BITS_PER_PASS)-1);\n"
|
||||
" uint4 keys = make_uint4( (sortedData.x)&mask, (sortedData.y)&mask, (sortedData.z)&mask, (sortedData.w)&mask );\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" \n"
|
||||
" AtomInc( ldsSortData[keys.x] );\n"
|
||||
" AtomInc( ldsSortData[keys.y] );\n"
|
||||
" AtomInc( ldsSortData[keys.z] );\n"
|
||||
" AtomInc( ldsSortData[keys.w] );\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void LocalSortKernel( DEFAULT_ARGS )\n"
|
||||
"{\n"
|
||||
" int nElemsPerWG = WG_SIZE*NUM_PER_WI;\n"
|
||||
" u32 lIdx = GET_LOCAL_IDX;\n"
|
||||
" u32 wgIdx = GET_GROUP_IDX;\n"
|
||||
" u32 wgSize = GET_GROUP_SIZE;\n"
|
||||
"\n"
|
||||
" uint4 localAddr = make_uint4(lIdx*4+0,lIdx*4+1,lIdx*4+2,lIdx*4+3);\n"
|
||||
"\n"
|
||||
"\n"
|
||||
" SortData sortData[NUM_PER_WI];\n"
|
||||
"\n"
|
||||
" {\n"
|
||||
" u32 offset = nElemsPerWG*wgIdx;\n"
|
||||
" sortData[0] = sortDataIn[offset+localAddr.x];\n"
|
||||
" sortData[1] = sortDataIn[offset+localAddr.y];\n"
|
||||
" sortData[2] = sortDataIn[offset+localAddr.z];\n"
|
||||
" sortData[3] = sortDataIn[offset+localAddr.w];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" int bitIdx = m_startBit;\n"
|
||||
" do\n"
|
||||
" {\n"
|
||||
"// what is this?\n"
|
||||
"// if( lIdx == wgSize-1 ) ldsSortData[256] = sortData[3].m_key;\n"
|
||||
" u32 mask = (1<<bitIdx);\n"
|
||||
" uint4 cmpResult = make_uint4( sortData[0].m_key & mask, sortData[1].m_key & mask, sortData[2].m_key & mask, sortData[3].m_key & mask );\n"
|
||||
" uint4 prefixSum = SELECT_UINT4( make_uint4(1,1,1,1), make_uint4(0,0,0,0), cmpResult != make_uint4(0,0,0,0) );\n"
|
||||
" u32 total;\n"
|
||||
" prefixSum = localPrefixSum128V( prefixSum, lIdx, total );\n"
|
||||
"\n"
|
||||
" {\n"
|
||||
" uint4 dstAddr = localAddr - prefixSum + make_uint4( total, total, total, total );\n"
|
||||
" dstAddr = SELECT_UINT4( prefixSum, dstAddr, cmpResult != make_uint4(0, 0, 0, 0) );\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" ldsSortData[dstAddr.x] = sortData[0].m_key;\n"
|
||||
" ldsSortData[dstAddr.y] = sortData[1].m_key;\n"
|
||||
" ldsSortData[dstAddr.z] = sortData[2].m_key;\n"
|
||||
" ldsSortData[dstAddr.w] = sortData[3].m_key;\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" sortData[0].m_key = ldsSortData[localAddr.x];\n"
|
||||
" sortData[1].m_key = ldsSortData[localAddr.y];\n"
|
||||
" sortData[2].m_key = ldsSortData[localAddr.z];\n"
|
||||
" sortData[3].m_key = ldsSortData[localAddr.w];\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" ldsSortData[dstAddr.x] = sortData[0].m_value;\n"
|
||||
" ldsSortData[dstAddr.y] = sortData[1].m_value;\n"
|
||||
" ldsSortData[dstAddr.z] = sortData[2].m_value;\n"
|
||||
" ldsSortData[dstAddr.w] = sortData[3].m_value;\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" sortData[0].m_value = ldsSortData[localAddr.x];\n"
|
||||
" sortData[1].m_value = ldsSortData[localAddr.y];\n"
|
||||
" sortData[2].m_value = ldsSortData[localAddr.z];\n"
|
||||
" sortData[3].m_value = ldsSortData[localAddr.w];\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
" }\n"
|
||||
" bitIdx ++;\n"
|
||||
" }\n"
|
||||
" while( bitIdx <(m_startBit+BITS_PER_PASS) );\n"
|
||||
"\n"
|
||||
" { // generate historgram\n"
|
||||
" uint4 localKeys = make_uint4( sortData[0].m_key>>m_startBit, sortData[1].m_key>>m_startBit, \n"
|
||||
" sortData[2].m_key>>m_startBit, sortData[3].m_key>>m_startBit );\n"
|
||||
"\n"
|
||||
" generateHistogram( lIdx, wgIdx, localKeys );\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" int nBins = (1<<BITS_PER_PASS);\n"
|
||||
" if( lIdx < nBins )\n"
|
||||
" {\n"
|
||||
" u32 histValues = ldsSortData[lIdx];\n"
|
||||
"\n"
|
||||
" u32 globalAddresses = nBins*wgIdx + lIdx;\n"
|
||||
" u32 globalAddressesRadixMajor = m_numGroups*lIdx + wgIdx;\n"
|
||||
" \n"
|
||||
" ldsHistogramOut0[globalAddressesRadixMajor] = histValues;\n"
|
||||
" ldsHistogramOut1[globalAddresses] = histValues;\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" { // write\n"
|
||||
" u32 offset = nElemsPerWG*wgIdx;\n"
|
||||
" uint4 dstAddr = make_uint4(offset+localAddr.x, offset+localAddr.y, offset+localAddr.z, offset+localAddr.w );\n"
|
||||
"\n"
|
||||
" sortDataIn[ dstAddr.x + 0 ] = sortData[0];\n"
|
||||
" sortDataIn[ dstAddr.x + 1 ] = sortData[1];\n"
|
||||
" sortDataIn[ dstAddr.x + 2 ] = sortData[2];\n"
|
||||
" sortDataIn[ dstAddr.x + 3 ] = sortData[3];\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"StructuredBuffer<SortData> src : register( t0 );\n"
|
||||
"StructuredBuffer<u32> histogramGlobalRadixMajor : register( t1 );\n"
|
||||
"StructuredBuffer<u32> histogramLocalGroupMajor : register( t2 );\n"
|
||||
"\n"
|
||||
"RWStructuredBuffer<SortData> dst : register( u0 );\n"
|
||||
"\n"
|
||||
"groupshared u32 ldsLocalHistogram[ 2*(1<<BITS_PER_PASS) ];\n"
|
||||
"groupshared u32 ldsGlobalHistogram[ (1<<BITS_PER_PASS) ];\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void ScatterKernel( DEFAULT_ARGS )\n"
|
||||
"{\n"
|
||||
" u32 lIdx = GET_LOCAL_IDX;\n"
|
||||
" u32 wgIdx = GET_GROUP_IDX;\n"
|
||||
" u32 ldsOffset = (1<<BITS_PER_PASS);\n"
|
||||
"\n"
|
||||
" // load and prefix scan local histogram\n"
|
||||
" if( lIdx < ((1<<BITS_PER_PASS)/2) )\n"
|
||||
" {\n"
|
||||
" uint2 myIdx = make_uint2(lIdx, lIdx+8);\n"
|
||||
"\n"
|
||||
" ldsLocalHistogram[ldsOffset+myIdx.x] = histogramLocalGroupMajor[(1<<BITS_PER_PASS)*wgIdx + myIdx.x];\n"
|
||||
" ldsLocalHistogram[ldsOffset+myIdx.y] = histogramLocalGroupMajor[(1<<BITS_PER_PASS)*wgIdx + myIdx.y];\n"
|
||||
" ldsLocalHistogram[ldsOffset+myIdx.x-(1<<BITS_PER_PASS)] = 0;\n"
|
||||
" ldsLocalHistogram[ldsOffset+myIdx.y-(1<<BITS_PER_PASS)] = 0;\n"
|
||||
"\n"
|
||||
" int idx = ldsOffset+2*lIdx;\n"
|
||||
" ldsLocalHistogram[idx] += ldsLocalHistogram[idx-1];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" ldsLocalHistogram[idx] += ldsLocalHistogram[idx-2];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" ldsLocalHistogram[idx] += ldsLocalHistogram[idx-4];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
" ldsLocalHistogram[idx] += ldsLocalHistogram[idx-8];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
"\n"
|
||||
" // Propagate intermediate values through\n"
|
||||
" ldsLocalHistogram[idx-1] += ldsLocalHistogram[idx-2];\n"
|
||||
" GROUP_MEM_FENCE;\n"
|
||||
"\n"
|
||||
" // Grab and propagate for whole WG - loading the - 1 value\n"
|
||||
" uint2 localValues;\n"
|
||||
" localValues.x = ldsLocalHistogram[ldsOffset+myIdx.x-1];\n"
|
||||
" localValues.y = ldsLocalHistogram[ldsOffset+myIdx.y-1];\n"
|
||||
"\n"
|
||||
" ldsLocalHistogram[myIdx.x] = localValues.x;\n"
|
||||
" ldsLocalHistogram[myIdx.y] = localValues.y;\n"
|
||||
"\n"
|
||||
"\n"
|
||||
" ldsGlobalHistogram[myIdx.x] = histogramGlobalRadixMajor[m_numGroups*myIdx.x + wgIdx];\n"
|
||||
" ldsGlobalHistogram[myIdx.y] = histogramGlobalRadixMajor[m_numGroups*myIdx.y + wgIdx];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" GROUP_LDS_BARRIER;\n"
|
||||
"\n"
|
||||
" uint4 localAddr = make_uint4(lIdx*4+0,lIdx*4+1,lIdx*4+2,lIdx*4+3);\n"
|
||||
"\n"
|
||||
" SortData sortData[4];\n"
|
||||
" {\n"
|
||||
" uint4 globalAddr = wgIdx*WG_SIZE*NUM_PER_WI + localAddr;\n"
|
||||
" sortData[0] = src[globalAddr.x];\n"
|
||||
" sortData[1] = src[globalAddr.y];\n"
|
||||
" sortData[2] = src[globalAddr.z];\n"
|
||||
" sortData[3] = src[globalAddr.w];\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" uint cmpValue = ((1<<BITS_PER_PASS)-1);\n"
|
||||
" uint4 radix = make_uint4( (sortData[0].m_key>>m_startBit)&cmpValue, (sortData[1].m_key>>m_startBit)&cmpValue, \n"
|
||||
" (sortData[2].m_key>>m_startBit)&cmpValue, (sortData[3].m_key>>m_startBit)&cmpValue );;\n"
|
||||
"\n"
|
||||
" // data is already sorted. So simply subtract local prefix sum\n"
|
||||
" uint4 dstAddr;\n"
|
||||
" dstAddr.x = ldsGlobalHistogram[radix.x] + (localAddr.x - ldsLocalHistogram[radix.x]);\n"
|
||||
" dstAddr.y = ldsGlobalHistogram[radix.y] + (localAddr.y - ldsLocalHistogram[radix.y]);\n"
|
||||
" dstAddr.z = ldsGlobalHistogram[radix.z] + (localAddr.z - ldsLocalHistogram[radix.z]);\n"
|
||||
" dstAddr.w = ldsGlobalHistogram[radix.w] + (localAddr.w - ldsLocalHistogram[radix.w]);\n"
|
||||
"\n"
|
||||
" dst[dstAddr.x] = sortData[0];\n"
|
||||
" dst[dstAddr.y] = sortData[1];\n"
|
||||
" dst[dstAddr.z] = sortData[2];\n"
|
||||
" dst[dstAddr.w] = sortData[3];\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"[numthreads(WG_SIZE, 1, 1)]\n"
|
||||
"void CopyKernel( DEFAULT_ARGS )\n"
|
||||
"{\n"
|
||||
" dst[ GET_GLOBAL_IDX ] = src[ GET_GLOBAL_IDX ];\n"
|
||||
"}\n"
|
||||
;
|
||||
@@ -0,0 +1,31 @@
|
||||
/*
|
||||
2011 Takahiro Harada
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <AdlPrimitives/Math/Math.h>
|
||||
|
||||
namespace adl
|
||||
{
|
||||
|
||||
struct SortData
|
||||
{
|
||||
SortData(){}
|
||||
SortData( u32 key, u32 value ) : m_key(key), m_value(value) {}
|
||||
|
||||
union
|
||||
{
|
||||
u32 m_key;
|
||||
struct { u16 m_key16[2]; };
|
||||
};
|
||||
u32 m_value;
|
||||
|
||||
friend bool operator <(const SortData& a, const SortData& b)
|
||||
{
|
||||
return a.m_key < b.m_key;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
};
|
||||
@@ -0,0 +1,146 @@
|
||||
/*
|
||||
2011 Takahiro Harada
|
||||
*/
|
||||
|
||||
#define PATH "..\\..\\AdlPrimitives\\Sort\\RadixSortAdvancedKernels"
|
||||
#define KERNEL0 "StreamCountKernel"
|
||||
#define KERNEL1 "SortAndScatterKernel1"
|
||||
#define KERNEL2 "PrefixScanKernel"
|
||||
|
||||
template<DeviceType type>
|
||||
class RadixSortAdvanced : public RadixSortBase
|
||||
{
|
||||
public:
|
||||
typedef Launcher::BufferInfo BufferInfo;
|
||||
|
||||
enum
|
||||
{
|
||||
WG_SIZE = 128,
|
||||
NUM_PER_WI = 4,
|
||||
MAX_NUM_WORKGROUPS = 60,
|
||||
};
|
||||
|
||||
struct Data : public RadixSort<type>::Data
|
||||
{
|
||||
Kernel* m_localCountKernel;
|
||||
Kernel* m_scatterKernel;
|
||||
Kernel* m_scanKernel;
|
||||
|
||||
Buffer<u32>* m_workBuffer0;
|
||||
Buffer<SortData>* m_workBuffer1;
|
||||
Buffer<int4>* m_constBuffer[32/4];
|
||||
};
|
||||
|
||||
|
||||
static
|
||||
Data* allocate(const Device* deviceData, int maxSize, Option option = SORT_NORMAL);
|
||||
|
||||
static
|
||||
void deallocate(void* data);
|
||||
|
||||
static
|
||||
void execute(void* data, Buffer<SortData>& inout, int n, int sortBits);
|
||||
};
|
||||
|
||||
template<DeviceType type>
|
||||
typename RadixSortAdvanced<type>::Data* RadixSortAdvanced<type>::allocate(const Device* deviceData, int maxSize, Option option)
|
||||
{
|
||||
ADLASSERT( type == deviceData->m_type );
|
||||
|
||||
const char* src[] = { 0, 0, 0 };
|
||||
|
||||
Data* data = new Data;
|
||||
data->m_option = option;
|
||||
data->m_deviceData = deviceData;
|
||||
|
||||
data->m_localCountKernel = deviceData->getKernel( PATH, KERNEL0, 0, src[type] );
|
||||
data->m_scatterKernel = deviceData->getKernel( PATH, KERNEL1, 0, src[type] );
|
||||
data->m_scanKernel = deviceData->getKernel( PATH, KERNEL2, 0, src[type] );
|
||||
|
||||
data->m_workBuffer0 = new Buffer<u32>( deviceData, MAX_NUM_WORKGROUPS*16 );
|
||||
data->m_workBuffer1 = new Buffer<SortData>( deviceData, maxSize );
|
||||
for(int i=0; i<32/4; i++)
|
||||
data->m_constBuffer[i] = new Buffer<int4>( deviceData, 1, BufferBase::BUFFER_CONST );
|
||||
data->m_maxSize = maxSize;
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
template<DeviceType type>
|
||||
void RadixSortAdvanced<type>::deallocate(void* rawData)
|
||||
{
|
||||
Data* data = (Data*)rawData;
|
||||
|
||||
delete data->m_workBuffer0;
|
||||
delete data->m_workBuffer1;
|
||||
for(int i=0; i<32/4; i++)
|
||||
delete data->m_constBuffer[i];
|
||||
|
||||
delete data;
|
||||
}
|
||||
|
||||
template<DeviceType type>
|
||||
void RadixSortAdvanced<type>::execute(void* rawData, Buffer<SortData>& inout, int n, int sortBits)
|
||||
{
|
||||
Data* data = (Data*)rawData;
|
||||
|
||||
ADLASSERT( sortBits == 32 );
|
||||
|
||||
ADLASSERT( NUM_PER_WI == 4 );
|
||||
ADLASSERT( n%(WG_SIZE*NUM_PER_WI) == 0 );
|
||||
ADLASSERT( MAX_NUM_WORKGROUPS < 128*8/16 );
|
||||
|
||||
Buffer<SortData>* src = &inout;
|
||||
Buffer<SortData>* dst = data->m_workBuffer1;
|
||||
|
||||
const Device* deviceData = data->m_deviceData;
|
||||
|
||||
int nBlocks = n/(NUM_PER_WI*WG_SIZE);
|
||||
const int nWorkGroupsToExecute = min2((int)MAX_NUM_WORKGROUPS, nBlocks);
|
||||
int nBlocksPerGroup = (nBlocks+nWorkGroupsToExecute-1)/nWorkGroupsToExecute;
|
||||
ADLASSERT( nWorkGroupsToExecute <= MAX_NUM_WORKGROUPS );
|
||||
|
||||
int4 constBuffer = make_int4(0, nBlocks, nWorkGroupsToExecute, nBlocksPerGroup);
|
||||
|
||||
int iPass = 0;
|
||||
int startBit = 0;
|
||||
for(int startBit=0; startBit<32; startBit+=4, iPass++)
|
||||
{
|
||||
constBuffer.x = startBit;
|
||||
|
||||
{
|
||||
BufferInfo bInfo[] = { BufferInfo( src, true ), BufferInfo( data->m_workBuffer0 ) };
|
||||
|
||||
Launcher launcher( deviceData, data->m_localCountKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[iPass], constBuffer );
|
||||
launcher.launch1D( WG_SIZE* nWorkGroupsToExecute, WG_SIZE );
|
||||
}
|
||||
|
||||
|
||||
{
|
||||
BufferInfo bInfo[] = { BufferInfo( data->m_workBuffer0 ) };
|
||||
|
||||
Launcher launcher( deviceData, data->m_scanKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[iPass], constBuffer );
|
||||
launcher.launch1D( WG_SIZE, WG_SIZE );
|
||||
}
|
||||
|
||||
{
|
||||
BufferInfo bInfo[] = { BufferInfo( data->m_workBuffer0, true ), BufferInfo( src ), BufferInfo( dst ) };
|
||||
|
||||
Launcher launcher( deviceData, data->m_scatterKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[iPass], constBuffer );
|
||||
launcher.launch1D( WG_SIZE*nWorkGroupsToExecute, WG_SIZE );
|
||||
}
|
||||
|
||||
swap2( src, dst );
|
||||
}
|
||||
}
|
||||
|
||||
#undef PATH
|
||||
#undef KERNEL0
|
||||
#undef KERNEL1
|
||||
#undef KERNEL2
|
||||
@@ -0,0 +1,149 @@
|
||||
/*
|
||||
2011 Takahiro Harada
|
||||
*/
|
||||
|
||||
#define PATH "..\\..\\opencl\\primitives\\AdlPrimitives\\Sort\\RadixSortSimpleKernels"
|
||||
#define KERNEL0 "LocalCountKernel"
|
||||
#define KERNEL1 "ScatterKernel"
|
||||
|
||||
#include <AdlPrimitives/Sort/RadixSortSimpleCL.h>
|
||||
#include <AdlPrimitives/Sort/RadixSortSimpleDX11.h>
|
||||
|
||||
template<DeviceType type>
|
||||
class RadixSortSimple : public RadixSortBase
|
||||
{
|
||||
public:
|
||||
typedef Launcher::BufferInfo BufferInfo;
|
||||
|
||||
enum
|
||||
{
|
||||
WG_SIZE = 128,
|
||||
NUM_PER_WI = 4,
|
||||
};
|
||||
|
||||
struct Data : public RadixSort<type>::Data
|
||||
{
|
||||
Kernel* m_localCountKernel;
|
||||
Kernel* m_scatterKernel;
|
||||
|
||||
Buffer<u32>* m_workBuffer0;
|
||||
Buffer<u32>* m_workBuffer1;
|
||||
Buffer<SortData>* m_workBuffer2;
|
||||
Buffer<int4>* m_constBuffer[4];
|
||||
};
|
||||
|
||||
|
||||
static
|
||||
Data* allocate(const Device* deviceData, int maxSize, Option option = SORT_NORMAL);
|
||||
|
||||
static
|
||||
void deallocate(void* data);
|
||||
|
||||
static
|
||||
void execute(void* data, Buffer<SortData>& inout, int n, int sortBits);
|
||||
};
|
||||
|
||||
template<DeviceType type>
|
||||
typename RadixSortSimple<type>::Data* RadixSortSimple<type>::allocate(const Device* deviceData, int maxSize, Option option)
|
||||
{
|
||||
ADLASSERT( type == deviceData->m_type );
|
||||
|
||||
const char* src[] =
|
||||
#if defined(ADL_LOAD_KERNEL_FROM_STRING)
|
||||
{radixSortSimpleKernelsCL, radixSortSimpleKernelsDX11};
|
||||
#else
|
||||
{ 0, 0 };
|
||||
#endif
|
||||
u32 maxNumGroups = (maxSize+WG_SIZE*NUM_PER_WI-1)/(WG_SIZE*NUM_PER_WI);
|
||||
|
||||
Data* data = new Data;
|
||||
data->m_option = option;
|
||||
data->m_deviceData = deviceData;
|
||||
|
||||
data->m_localCountKernel = deviceData->getKernel( PATH, KERNEL0, 0, src[type] );
|
||||
data->m_scatterKernel = deviceData->getKernel( PATH, KERNEL1, 0, src[type] );
|
||||
|
||||
data->m_scanData = PrefixScan<type>::allocate( deviceData, maxSize );
|
||||
|
||||
data->m_workBuffer0 = new Buffer<u32>( deviceData, maxNumGroups*256 );
|
||||
data->m_workBuffer1 = new Buffer<u32>( deviceData, maxNumGroups*256 );
|
||||
data->m_workBuffer2 = new Buffer<SortData>( deviceData, maxSize );
|
||||
data->m_constBuffer[0] = new Buffer<int4>( deviceData, 1, BufferBase::BUFFER_CONST );
|
||||
data->m_constBuffer[1] = new Buffer<int4>( deviceData, 1, BufferBase::BUFFER_CONST );
|
||||
data->m_constBuffer[2] = new Buffer<int4>( deviceData, 1, BufferBase::BUFFER_CONST );
|
||||
data->m_constBuffer[3] = new Buffer<int4>( deviceData, 1, BufferBase::BUFFER_CONST );
|
||||
data->m_maxSize = maxSize;
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
template<DeviceType type>
|
||||
void RadixSortSimple<type>::deallocate(void* rawData)
|
||||
{
|
||||
Data* data = (Data*)rawData;
|
||||
|
||||
delete data->m_workBuffer0;
|
||||
delete data->m_workBuffer1;
|
||||
delete data->m_workBuffer2;
|
||||
delete data->m_constBuffer[0];
|
||||
delete data->m_constBuffer[1];
|
||||
delete data->m_constBuffer[2];
|
||||
delete data->m_constBuffer[3];
|
||||
|
||||
PrefixScan<type>::deallocate( data->m_scanData );
|
||||
|
||||
delete data;
|
||||
}
|
||||
|
||||
template<DeviceType type>
|
||||
void RadixSortSimple<type>::execute(void* rawData, Buffer<SortData>& inout, int n, int sortBits)
|
||||
{
|
||||
Data* data = (Data*)rawData;
|
||||
|
||||
ADLASSERT( sortBits == 32 );
|
||||
ADLASSERT( n%512 == 0 );
|
||||
ADLASSERT( n <= data->m_maxSize );
|
||||
|
||||
Buffer<SortData>* src = &inout;
|
||||
Buffer<SortData>* dst = data->m_workBuffer2;
|
||||
|
||||
const Device* deviceData = data->m_deviceData;
|
||||
|
||||
int numGroups = (n+WG_SIZE*NUM_PER_WI-1)/(WG_SIZE*NUM_PER_WI);
|
||||
|
||||
int4 constBuffer;
|
||||
|
||||
int iPass = 0;
|
||||
for(int startBit=0; startBit<32; startBit+=8, iPass++)
|
||||
{
|
||||
constBuffer.x = startBit;
|
||||
constBuffer.y = numGroups;
|
||||
constBuffer.z = WG_SIZE;
|
||||
|
||||
{
|
||||
BufferInfo bInfo[] = { BufferInfo( src, true ), BufferInfo( data->m_workBuffer0 ) };
|
||||
|
||||
Launcher launcher( deviceData, data->m_localCountKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[iPass], constBuffer );
|
||||
launcher.launch1D( WG_SIZE*numGroups, WG_SIZE );
|
||||
}
|
||||
|
||||
PrefixScan<type>::execute( data->m_scanData, *data->m_workBuffer0, *data->m_workBuffer1, numGroups*256 );
|
||||
|
||||
{
|
||||
BufferInfo bInfo[] = { BufferInfo( src, true ), BufferInfo( dst ), BufferInfo( data->m_workBuffer1 ) };
|
||||
|
||||
Launcher launcher( deviceData, data->m_scatterKernel );
|
||||
launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
|
||||
launcher.setConst( *data->m_constBuffer[iPass], constBuffer );
|
||||
launcher.launch1D( WG_SIZE*numGroups, WG_SIZE );
|
||||
}
|
||||
|
||||
swap2( src, dst );
|
||||
}
|
||||
}
|
||||
|
||||
#undef PATH
|
||||
#undef KERNEL0
|
||||
#undef KERNEL1
|
||||
Reference in New Issue
Block a user