Question

I've been trying to design a fast binary exponentiation implementation in OpenCL. My current implementation is very similar to the one in this book about pi.

// Returns 16^n mod ak
inline double expm (long n, double ak)
{
    double r = 16.0;
    long nt;

    if (ak == 1) return 0.;
    if (n == 0) return 1;
    if (n == 1) return fmod(16.0, ak);

    for (nt=1; nt <= n; nt <<=1);

    nt >>= 2;

    do
    {
        r = fmod(r*r, ak);
        if ((n & nt) != 0)
            r = fmod(16.0*r, ak);
        nt >>= 1;
    } while (nt != 0);
    return r;
}

Is there room for improvement? Right now my program is spending the vast majority of its time in this function.


Solution

My first thought is to vectorize it, for a potential speedup of ~1.6x. This uses 5 multiplies per loop iteration compared to 2 in the original, but needs only about a quarter as many iterations for sufficiently large n. Converting all the doubles to longs and replacing the fmod calls with % may also provide some speedup, depending on the GPU.

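inline double expm(long n, double ak) {

    const double4 m = (double4)(ak); //fmod() needs both operands to be vectors.
    double4 r = (double4)(1.0);
    //Lane i keeps only bit i of every 4-bit group of n.
    long4 ns = (long4)(n) & (long4)(0x1111111111111111L, 0x2222222222222222L,
            0x4444444444444444L, (long)0x8888888888888888UL);
    long nt;

    if(ak == 1) return 0.;

    for(nt=15; nt<n; nt<<=4); //This can probably be vectorized somehow as well.

    do {
        //Raise r to the 16th power, same as multiplying the exponent
        //(of the result) by 16, same as bitshifting the exponent to the
        //left 4 bits. Reducing after every squaring keeps each product
        //exactly representable in a double.
        r = fmod(r*r, m);
        r = fmod(r*r, m);
        r = fmod(r*r, m);
        r = fmod(r*r, m);

        //Multiply only the lanes whose bit is set in the current group:
        //select() takes its second argument where the mask's top bit is set,
        //which is the case when (ns & nt) is zero and the subtraction gives -1.
        r = select(fmod(r*(double4)(16.0, 256.0, 65536.0, 4294967296.0), m), r,
                (ns & nt) - 1);
        nt >>= 4;
    } while(nt != 0); //Process n four bits at a time.

    //And then combine all of them, reducing between the products.
    return fmod(fmod(r.x*r.y, ak) * fmod(r.z*r.w, ak), ak);
}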

Edit: I'm pretty sure it works now.
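
The suggestion above to trade doubles and fmod for integers and % could look roughly like the following. This is a minimal sketch in plain C rather than OpenCL C, and the name expm_u64 and the restriction 0 < m < 2^32 (so that every product fits in 64 bits) are assumptions made for the illustration, not something taken from the question.

```c
#include <stdint.h>

/* Hypothetical integer variant: returns 16^n mod m.
 * Assumes 0 < m < 2^32 so that r * r never overflows 64 bits. */
static uint64_t expm_u64(uint64_t n, uint64_t m)
{
    uint64_t r = 1 % m;  /* 16^0 mod m; this is 0 when m == 1 */
    int bit;

    /* Scan the bits of n from the most significant end. Leading zero
     * bits only square 1, so they leave r unchanged. */
    for (bit = 63; bit >= 0; --bit) {
        r = (r * r) % m;          /* squaring doubles the exponent */
        if ((n >> bit) & 1)
            r = (r * 16) % m;     /* multiplying by 16 adds 1 to it */
    }
    return r;
}
```

For example, expm_u64(10, 1000) evaluates to 776, the last three digits of 16^10 = 1099511627776. Whether this beats the fmod version will depend on how the target device handles 64-bit integer division, so it is worth measuring rather than assuming.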

OTHER TIPS

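  • The loop that finds nt (the mask for the highest set bit of n) can be replaced by
    if (n & 1) ...; n >>= 1;
    in the do-while loop. This scans n from the low bit upward, so the repeatedly squared base has to be kept in a variable separate from the result.
  • Given that r starts at 16, the modulo in fmod(r*r, ak) and fmod(16*r, ak) can be delayed and calculated only every Nth iteration or so, as long as the intermediate values stay exactly representable -- Loop unrolling?
  • Also why fmod?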
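
The first two tips can be sketched as follows. This is only one reading of them, written in plain C with integers and % instead of fmod (in the spirit of the third tip). The function names and the bound 0 < m <= 2^30, which keeps 16 * r * r below 2^64, are assumptions made for the illustration.

```c
#include <stdint.h>

/* Tip 1, as a sketch: test the low bit of n and shift n right, so no
 * separate search for the top bit is needed. Assumes 0 < m <= 2^30. */
static uint64_t expm_low_bit_first(uint64_t n, uint64_t m)
{
    uint64_t result = 1 % m;   /* 0 when m == 1 */
    uint64_t base = 16 % m;    /* holds 16^(2^k) mod m on iteration k */

    while (n != 0) {
        if (n & 1)
            result = (result * base) % m;
        base = (base * base) % m;
        n >>= 1;
    }
    return result;
}

/* Tip 2, as a sketch: square and multiply by 16 first, then reduce once
 * per bit instead of twice. Assumes 0 < m <= 2^30. */
static uint64_t expm_one_reduction(uint64_t n, uint64_t m)
{
    uint64_t r = 1 % m;
    int bit;

    for (bit = 63; bit >= 0; --bit) {
        uint64_t sq = r * r;                    /* below 2^60 */
        r = (((n >> bit) & 1) ? 16 * sq : sq) % m;
    }
    return r;
}
```

Both functions compute 16^n mod m. The first does two multiplications by arbitrary values per set bit, while the second keeps the cheap multiplication by 16 and halves the number of reductions, at the cost of a tighter bound on m.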
Licensed under: CC-BY-SA with attribution