I'm providing a late answer to this question to remove it from the unanswered list.
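To achieve what you want, you can define an array of `unsigned int`s of length ⌈N/32⌉ (one 32-bit word per warp), where `N` is the length of the arrays you are comparing. Each thread then uses an atomic operation (`atomicAdd`, or equivalently `atomicOr`, since every thread touches a distinct bit) to set its own bit of that array, depending on whether the corresponding elements of the two arrays are equal or not (in the example below, a bit is set where the elements differ).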
Below is a simple example:
#include <climits>
#include <cstdio>
#include <iostream>
#include <thrust/device_vector.h>
// Per-thread stand-in for the warp-collective __ballot: returns a word that has
// only this thread's bit, bit (threadIdx.x % 32), set when `predicate` is
// non-zero, and 0 otherwise. OR-ing (or adding) the results of all threads of a
// warp into one zeroed word reproduces the warp's ballot mask.
// NOTE(review): assumes a 1D block, so that threadIdx.x % 32 is the lane index.
__device__ unsigned int __ballot_non_atom(int predicate)
{
    // Unsigned 1u: with a signed 1, lane 31 would compute 1 << 31, which
    // overflows int (undefined behavior before C++20).
    return (predicate != 0) ? (1u << (threadIdx.x & 31u)) : 0u;
}
// Builds a bit mask of the positions where two float arrays differ.
// Each thread compares one element pair (index tid = threadIdx.x + blockIdx.x * blockDim.x).
// Where d_vec1_ptr[tid] != d_vec2_ptr[tid], the thread sets bit (threadIdx.x % 32) of
// word d_result[threadIdx.x / 32 + blockIdx.x * Num_Warps_per_Block].
//
// Launch contract:
//  - 1D grid and 1D blocks. blockDim.x should be a multiple of 32, and
//    Num_Warps_per_Block == blockDim.x / 32.
//  - d_result holds gridDim.x * Num_Warps_per_Block words, zeroed before launch
//    (bits are only ever set, never cleared).
//  - N: number of valid elements. Threads with tid >= N do nothing, so their bits
//    stay 0. The default (negative) disables the check, and the grid must then cover
//    exactly the valid elements.
__global__ void check_if_equal_elements(const float* d_vec1_ptr, const float* d_vec2_ptr, unsigned int* d_result, int Num_Warps_per_Block, int N = -1)
{
    const int tid = threadIdx.x + blockIdx.x * blockDim.x;
    // An early exit is safe here: there is no __syncthreads and no warp-collective op,
    // so a partially filled last warp just leaves its missing bits at 0.
    if (N >= 0 && tid >= N) return;

    const unsigned int warp_num = threadIdx.x >> 5;
    const unsigned int bit = __ballot_non_atom(d_vec1_ptr[tid] != d_vec2_ptr[tid]);

    // atomicOr (rather than atomicAdd) is idempotent and cannot carry into a
    // neighbouring bit. Skipping zero contributions avoids useless atomics.
    if (bit != 0u)
        atomicOr(&d_result[warp_num + blockIdx.x * Num_Warps_per_Block], bit);
}
// --- Credit to "C printing bits": http://stackoverflow.com/questions/9280654/c-printing-bits
// Prints every bit of `num`, from most significant to least significant, as
// "0 " or "1 ". No newline is written, so consecutive calls concatenate into one
// row. Output goes to `out` (stdout by default).
void printBits(unsigned int num, FILE* out = stdout){
    const unsigned int numBits = sizeof(num) * CHAR_BIT;
    // Shift num right instead of shifting a signed 1 left: 1 << 31 on an int
    // overflows (undefined behavior).
    for (unsigned int bit = numBits; bit-- > 0u;) {
        std::fprintf(out, "%u ", (num >> bit) & 1u);
    }
}
// Demo: two 64-element device vectors that differ only at indices 3 and 7 are
// compared by check_if_equal_elements. The resulting per-warp bit words are
// printed highest word first, so the whole mask reads most-significant bit first.
int main(void)
{
    const unsigned int N = 64;
    thrust::device_vector<float> d_vec1(N, 1.f);
    thrust::device_vector<float> d_vec2(N, 1.f);
    d_vec2[3] = 3.f;
    d_vec2[7] = 4.f;

    const unsigned int Num_Threads_per_Block = 64;   // multiple of the warp size (32)
    // The kernel is launched without a bounds check, so the grid must cover N exactly.
    if (N % Num_Threads_per_Block != 0) {
        fprintf(stderr, "N must be a multiple of the block size\n");
        return 1;
    }
    const unsigned int Num_Blocks_per_Grid = N / Num_Threads_per_Block;
    const unsigned int Num_Warps_per_Block = Num_Threads_per_Block / 32;
    const unsigned int Num_Warps_per_Grid = (Num_Threads_per_Block * Num_Blocks_per_Grid) / 32;

    // One 32-bit word per warp, zero-initialized because the kernel only sets bits.
    thrust::device_vector<unsigned int> d_result(Num_Warps_per_Grid, 0);
    check_if_equal_elements<<<Num_Blocks_per_Grid, Num_Threads_per_Block>>>(
        thrust::raw_pointer_cast(d_vec1.data()),
        thrust::raw_pointer_cast(d_vec2.data()),
        thrust::raw_pointer_cast(d_result.data()),
        Num_Warps_per_Block);

    // Launches report errors only through the runtime: check the configuration,
    // then wait for the kernel so execution faults surface here.
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess) err = cudaDeviceSynchronize();
    if (err != cudaSuccess) {
        fprintf(stderr, "check_if_equal_elements failed: %s\n", cudaGetErrorString(err));
        return 1;
    }

    // Highest-indexed word first, so the concatenated output is one MSB-first mask.
    for (unsigned int w = Num_Warps_per_Grid; w-- > 0u;) {
        const unsigned int val = d_result[w];   // device -> host copy
        printBits(val);
    }
    printf("\n");

    getchar();   // keep the console window open
    return 0;
}