cublasSgemmBatched usage with jcuda
Question
I've been trying to use the cublasSgemmBatched() function in JCuda for matrix multiplication, and I'm not sure how to properly handle the pointer passing and the arrays of batched matrices. I would be very thankful if someone could show me how to modify my code to handle this correctly. In this example, the C array stays unchanged after cublasGetVector.
/**
 * Computes C = A * B on the GPU with cublasSgemmBatched, using a batch of one matrix.
 *
 * All matrices use cuBLAS' column-major layout with no transposition:
 * A is m x k (leading dimension m), B is k x n (leading dimension k),
 * and C is m x n (leading dimension m).
 *
 * @param m rows of A and C
 * @param n columns of B and C
 * @param k columns of A and rows of B
 * @param A host array holding at least m*k floats (column-major)
 * @param B host array holding at least k*n floats (column-major)
 * @return a new host array of m*n floats containing the product (column-major)
 * @throws IllegalArgumentException if a dimension is negative or an input array is too short
 * @throws ArithmeticException if a matrix element count does not fit in an int
 * @throws IllegalStateException if cublasSgemmBatched reports a non-success status
 */
public static float[] SsmmBatchJCublas(int m, int n, int k, float A[], float B[]) {
    if (m < 0 || n < 0 || k < 0) {
        throw new IllegalArgumentException("Matrix dimensions must be non-negative");
    }
    // Element counts; multiplyExact guards against silent int overflow.
    int sizeA = Math.multiplyExact(m, k);
    int sizeB = Math.multiplyExact(k, n);
    int sizeC = Math.multiplyExact(m, n);
    if (A.length < sizeA || B.length < sizeB) {
        throw new IllegalArgumentException("Input arrays are smaller than m*k / k*n");
    }

    float[] C = new float[sizeC];
    // With an empty result, or k == 0 (an empty sum and beta == 0), C is all zeros already.
    if (sizeC == 0 || k == 0) {
        return C;
    }

    final int batchCount = 1;

    // Create a CUBLAS handle
    cublasHandle handle = new cublasHandle();
    cublasCreate(handle);

    // Device buffers for the matrices themselves
    Pointer d_A = new Pointer();
    Pointer d_B = new Pointer();
    Pointer d_C = new Pointer();
    // Device buffers for the arrays of device pointers required by the batched API
    Pointer d_Aarray = new Pointer();
    Pointer d_Barray = new Pointer();
    Pointer d_Carray = new Pointer();
    try {
        // Allocate memory on the device (long arithmetic avoids byte-count overflow)
        cudaMalloc(d_A, (long) sizeA * Sizeof.FLOAT);
        cudaMalloc(d_B, (long) sizeB * Sizeof.FLOAT);
        cudaMalloc(d_C, (long) sizeC * Sizeof.FLOAT);

        // Copy the inputs from the host to the device. C does not need uploading,
        // because beta == 0 means its previous contents are never read.
        cublasSetVector(sizeA, Sizeof.FLOAT, Pointer.to(A), 1, d_A, 1);
        cublasSetVector(sizeB, Sizeof.FLOAT, Pointer.to(B), 1, d_B, 1);

        // cuBLAS batched routines read Aarray/Barray/Carray on the GPU, so these
        // arrays of device pointers must themselves be copied into device memory.
        // Passing Pointer.to(Pointer[]) directly would hand cuBLAS a host address.
        cudaMalloc(d_Aarray, (long) batchCount * Sizeof.POINTER);
        cudaMalloc(d_Barray, (long) batchCount * Sizeof.POINTER);
        cudaMalloc(d_Carray, (long) batchCount * Sizeof.POINTER);
        cublasSetVector(batchCount, Sizeof.POINTER, Pointer.to(new Pointer[]{d_A}), 1, d_Aarray, 1);
        cublasSetVector(batchCount, Sizeof.POINTER, Pointer.to(new Pointer[]{d_B}), 1, d_Barray, 1);
        cublasSetVector(batchCount, Sizeof.POINTER, Pointer.to(new Pointer[]{d_C}), 1, d_Carray, 1);

        // Execute sgemm. The leading dimensions are the row counts of the
        // column-major operands, not the batch size.
        Pointer pAlpha = Pointer.to(new float[]{1.0f});
        Pointer pBeta = Pointer.to(new float[]{0.0f});
        int status = cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
            pAlpha, d_Aarray, m, d_Barray, k, pBeta, d_Carray, m, batchCount);
        if (status != 0) { // 0 == CUBLAS_STATUS_SUCCESS
            throw new IllegalStateException("cublasSgemmBatched failed with status " + status);
        }

        // Copy the result from the device to the host
        cublasGetVector(sizeC, Sizeof.FLOAT, d_C, 1, Pointer.to(C), 1);
    } finally {
        // Clean up even if a call above throws. Freeing a never-allocated (null)
        // device pointer is a no-op in CUDA.
        cudaFree(d_Aarray);
        cudaFree(d_Barray);
        cudaFree(d_Carray);
        cudaFree(d_A);
        cudaFree(d_B);
        cudaFree(d_C);
        cublasDestroy(handle);
    }
    return C;
}
Solution
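I asked on the official JCuda forum and quickly received an answer there. In short: cuBLAS batched routines expect Aarray, Barray and Carray to be arrays of device pointers that themselves reside in device memory. The Java Pointer[] arrays must therefore be copied to the GPU first, for example by allocating with cudaMalloc and copying with cudaMemcpy or with cublasSetVector using Sizeof.POINTER, rather than being passed as host memory via Pointer.to(...). In addition, the leading dimensions must be the row counts of the column-major matrices (lda = m, ldb = k, ldc = m for the non-transposed case), not the batch size. Otherwise the call is rejected with CUBLAS_STATUS_INVALID_VALUE and C is never written.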
Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow