Question

I have a 2D matrix multiplication program using the following kernel:

__global__ void multKernel(int *a, int *b, int *c, int N)
{
    int column  = threadIdx.x + blockDim.x * blockIdx.x;
    int row     = threadIdx.y + blockDim.y * blockIdx.y;

    int index = row * N + column;

    if(column < N && row < N)
    {
        c[index] = a[index] * b[index];
    }
}

Now I'd like to create a 3D matrix multiplication kernel, but I'm having trouble finding examples of how to create one (also, I'm terrible at reading mathematical formulae; it's something I need to improve on).

I know the GPU example will involve using

threadIdx.z

and so on, but I'm a bit lost with how to do it.

Could anyone point me in the right direction to either some formulae or sample code? Or even provide a basic example? I have a CPU example which should work, I think.

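void matrixMult3D(int *a, int *b, int *c, int *z, int N)
{
    // Element-wise product over an N x N x N array stored contiguously.
    // The depth loop variable is called "depth" so it does not shadow
    // the pointer parameter z.
    for(int depth = 0; depth < N; depth++)
    {
        for(int row = 0; row < N; row++)
        {
            for(int column = 0; column < N; column++)
            {
                // Row-major 3D index: each depth slice holds N * N elements.
                int index = (depth * N + row) * N + column;
                c[index] = a[index] * b[index] * z[index];
            }
        }
    }
}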

Am I at least on the right track?
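One way to sanity-check a 3D index formula on the CPU is to compare a triple loop against a single flat loop over all N*N*N elements. The sketch below assumes a contiguous row-major layout with index = (depth * N + row) * N + column. The function names are hypothetical, and the depth loop variable is deliberately not called z so it cannot shadow a pointer parameter of that name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Element-wise product c = a * b * z over an N x N x N array stored as one
// contiguous buffer, written as an explicit triple loop.
// Assumed row-major layout with the depth index slowest:
//     index = (depth * N + row) * N + column
std::vector<int> elementwise3DNested(const std::vector<int>& a,
                                     const std::vector<int>& b,
                                     const std::vector<int>& z, int N)
{
    assert(a.size() == static_cast<std::size_t>(N) * N * N);
    std::vector<int> c(a.size());
    for (int depth = 0; depth < N; depth++)
        for (int row = 0; row < N; row++)
            for (int column = 0; column < N; column++)
            {
                int index = (depth * N + row) * N + column;
                c[index] = a[index] * b[index] * z[index];
            }
    return c;
}

// The same product as one flat loop over all elements. Each output element
// depends only on the inputs at the same index, so the shape never matters.
std::vector<int> elementwiseFlat(const std::vector<int>& a,
                                 const std::vector<int>& b,
                                 const std::vector<int>& z)
{
    std::vector<int> c(a.size());
    for (std::size_t i = 0; i < a.size(); i++)
        c[i] = a[i] * b[i] * z[i];
    return c;
}
```

If both functions return identical vectors for the same inputs, the triple loop's index formula covers every element of the N*N*N buffer.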


Solution

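Because what you are actually doing is just an element-wise product (I hesitate to call it a Hadamard product because, AFAIK, that isn't defined for hypermatrices), you don't need to do anything differently from the simplest 1D version of your kernel code. Each output element depends only on the input elements at the same index, so the data can be treated as one flat array of N^ndim elements whatever its shape. Something like this: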

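template<int ndim>
__global__ void multKernel(int *a, int *b, int *c, int *z, int N)
{
    // Global thread index and total number of threads in the grid.
    int idx    = threadIdx.x + blockDim.x * blockIdx.x;
    int stride = blockDim.x * gridDim.x;

    // Total element count: N^ndim (N, N*N, N*N*N, ...).
    int idxmax = 1;
    #pragma unroll
    for(int i = 0; i < ndim; i++) {
        idxmax *= N;
    }

    // Grid-stride loop: each thread handles idx, idx + stride, idx + 2*stride, ...
    for(; idx < idxmax; idx += stride) {
        c[idx] = a[idx] * b[idx] * z[idx];
    }
}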

[Disclaimer: code written in browser, never compiled or run. Use at your own risk.]

This would work for arrays of any dimensionality: N elements (ndim=1), N*N (ndim=2), N*N*N (ndim=3), and so on, because the grid-stride loop lets a 1D grid of any size cover all N^ndim elements.
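To see why the grid-stride loop covers every element exactly once, even when the grid has far fewer threads than there are elements, here is a host-side C++ emulation of the kernel's index arithmetic. It is a sketch, not CUDA code: LaunchConfig, gridStrideWriteCounts, and the loop variables standing in for blockIdx.x and threadIdx.x are hypothetical names introduced for illustration. Instead of doing the multiplication, it counts how many times each element of c would be written.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for a 1D CUDA launch configuration.
// These are plain ints, not the CUDA built-ins gridDim/blockDim.
struct LaunchConfig {
    int gridDimX;   // number of blocks (plays the role of gridDim.x)
    int blockDimX;  // threads per block (plays the role of blockDim.x)
};

// Host-side emulation of the grid-stride loop in multKernel<ndim>.
// Every (block, thread) pair runs the kernel's loop sequentially; instead of
// computing c[idx] = a[idx] * b[idx] * z[idx], it counts the writes to c[idx].
template <int ndim>
std::vector<int> gridStrideWriteCounts(int N, LaunchConfig cfg)
{
    assert(N > 0 && cfg.gridDimX > 0 && cfg.blockDimX > 0);

    // N^ndim elements in total, computed the same way as idxmax in the kernel.
    int idxmax = 1;
    for (int i = 0; i < ndim; i++)
        idxmax *= N;

    std::vector<int> writes(idxmax);  // zero-initialized write counters
    const int stride = cfg.blockDimX * cfg.gridDimX;

    for (int block = 0; block < cfg.gridDimX; block++)           // blockIdx.x
        for (int thread = 0; thread < cfg.blockDimX; thread++)   // threadIdx.x
            for (int idx = thread + cfg.blockDimX * block; idx < idxmax; idx += stride)
                writes[idx]++;  // stands in for the element-wise product

    return writes;
}
```

With 2 blocks of 4 threads (stride 8) and N = 5, for example, each of the 125 elements in the ndim=3 case is written exactly once; a grid larger than the array simply leaves the extra threads idle.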

Licensed under: CC-BY-SA with attribution