Your pow function is inefficient:
for(int k = 0 ; k < root -2 ; k++)
tmpx = _mm256_mul_ps(x,tmpx);
In your example you're taking the 29th root, so you need pow(x, 29-1) = x^28. Currently you use 27 multiplications for that, but it can be done in only six multiplications:
x^28 = (x^4)*(x^8)*(x^16)
x^4 = (x*x)*(x*x) = y -> 2 multiplications
x^8 = y*y = z -> 1 multiplication
x^16 = z*z = w -> 1 multiplication
y*z*w -> 2 multiplications
6 multiplications in total
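To make the count concrete, here is a scalar sketch comparing the two approaches. Both function names are hypothetical, and each function counts the multiplications it performs. pow_loop assumes that the question's loop starts from tmpx = x; that initialization isn't shown above.

```c
/* Hypothetical: mirrors the question's loop, ASSUMING tmpx starts at x.
   Multiplies by x (root-2) times, giving x^(root-1). */
float pow_loop(float x, int root, int *mults) {
    float tmpx = x;
    *mults = 0;
    for (int k = 0; k < root - 2; k++) {
        tmpx = x * tmpx;
        (*mults)++;
    }
    return tmpx;
}

/* Hypothetical: x^28 via the y, z, w chain above. */
float pow28_chain(float x, int *mults) {
    int m = 0;
    float x2 = x * x;   m++;  /* x^2      */
    float y  = x2 * x2; m++;  /* y = x^4  */
    float z  = y * y;   m++;  /* z = x^8  */
    float w  = z * z;   m++;  /* w = x^16 */
    float r  = y * z;   m++;  /* x^12     */
    r = r * w;          m++;  /* x^28     */
    *mults = m;
    return r;
}
```

With root = 29, pow_loop(2.0f, 29, &m) and pow28_chain(2.0f, &m) both return 2^28 = 268435456. The first sets m to 27 and the second sets it to 6.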
Here is an improved version of your code, which is about twice as fast on my system. It uses a new function I wrote, pow_avx_fast, which computes x^n for 8 floats at once using AVX. For example, it computes x^28 in 6 multiplications instead of 27. Further down in my answer there is also a version which finds the result within some tolerance xacc. That version could be much faster if the iteration converges quickly.
#include <immintrin.h>
inline __m256 pow_avx_fast(__m256 x, const int n) {
//n must be greater than zero
if(n%2 == 0) {
return pow_avx_fast(_mm256_mul_ps(x, x), n/2);
}
else {
if(n>1) return _mm256_mul_ps(x,pow_avx_fast(_mm256_mul_ps(x, x), (n-1)/2));
return x;
}
}
inline __m256 SSEnthRoot_fast(__m256 a, int root) {
// x_(n+1) = (1/root)*((root-1) * x_n + a / pow(x_n,root-1))
__m256 R = _mm256_set1_ps((float)root);
__m256 Ni = _mm256_rcp_ps(R); //approximate 1/root (about 12 bits); _mm256_set1_ps(1.0f/root) is correctly rounded
__m256 Nm = _mm256_set1_ps((float)(root -1));
__m256 x = _mm256_mul_ps(a,Ni);
for(int i = 0; i < 20 ; i ++) {
__m256 tmpx = pow_avx_fast(x, root-1);
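//tar = a / x^(root-1); _mm256_rcp_ps is approximate, _mm256_div_ps is correctly rounded but slower
__m256 tar = _mm256_mul_ps(a,_mm256_rcp_ps(tmpx));
//tar = (root-1)*x + tar; with FMA3 this can be a single fused multiply-add:
//tar = _mm256_fmadd_ps(Nm,x,tar);
tar = _mm256_add_ps(_mm256_mul_ps(Nm,x),tar);
//x = (1/root) * tar
x = _mm256_mul_ps(Ni,tar);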
}
return x;
}
For more information on how to write an efficient pow function, see these links: http://en.wikipedia.org/wiki/Addition-chain_exponentiation and http://en.wikipedia.org/wiki/Exponentiation_by_squaring
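Here is a small scalar illustration of both ideas. pow_sq follows the same recursion as pow_avx_fast and counts its multiplications. pow15_chain instead uses the addition chain 1, 2, 3, 6, 12, 15. For n = 15, squaring needs 6 multiplications but that addition chain needs only 5, so exponentiation by squaring is not always optimal. Both function names are hypothetical, and the caller must zero the counter before calling pow_sq.

```c
/* Hypothetical scalar version of pow_avx_fast's recursion
   (exponentiation by squaring) with a multiplication counter. n >= 1. */
float pow_sq(float x, int n, int *mults) {
    if (n == 1) return x;
    (*mults)++;                                 /* the x*x squaring */
    if (n % 2 == 0) return pow_sq(x * x, n / 2, mults);
    (*mults)++;                                 /* the extra multiply by x */
    return x * pow_sq(x * x, (n - 1) / 2, mults);
}

/* Hypothetical: x^15 using the addition chain 1, 2, 3, 6, 12, 15. */
float pow15_chain(float x, int *mults) {
    float x2  = x * x;    /* 1: x^2  = x^1 * x^1  */
    float x3  = x2 * x;   /* 2: x^3  = x^2 * x^1  */
    float x6  = x3 * x3;  /* 3: x^6  = x^3 * x^3  */
    float x12 = x6 * x6;  /* 4: x^12 = x^6 * x^6  */
    *mults = 5;           /* includes the final multiply below */
    return x12 * x3;      /* 5: x^15 = x^12 * x^3 */
}
```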
Also, your initial guess might not be very good. Here is scalar code to find the nth root based on your method, but using the math library's pow function, which is probably faster than yours. It takes about 50 iterations to solve for the 4th root of 16 (which is 2). With the 20 iterations you use, it returns over 4000, which is nowhere close to 2.0. So you will need to adjust your method to do enough iterations to ensure a reasonable answer within some tolerance.
#include <math.h>
#include <stdio.h>
float fx(float a, int n, float x) {
return 1.0f/n * ((n-1)*x + a/pow(x, n-1));
}
float scalar_nthRoot_v2(float a, int root) {
//sets initial guess to 1 / (a * root)
float x = 1.0f/(a*root);
printf("x0 %f\n", x);
for(int i = 0; i<50; i++) {
x = fx(a, root, x);
printf("x %f\n", x);
}
return x;
}
I got the formula for Newton's method from here: http://en.wikipedia.org/wiki/Nth_root_algorithm
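For reference, the update follows from applying Newton's method to f(x) = x^n - a, with n = root:

```latex
f(x) = x^n - a, \qquad f'(x) = n\,x^{n-1}

x_{k+1} = x_k - \frac{f(x_k)}{f'(x_k)}
        = x_k - \frac{x_k^n - a}{n\,x_k^{n-1}}
        = \frac{1}{n}\left((n-1)\,x_k + \frac{a}{x_k^{n-1}}\right)
```

The a / x^(n-1) term is why every iteration needs pow(x, root-1).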
Here is a version of your function which returns the result once it is within a tolerance xacc, or quits if it has not converged after nmax iterations. This function could be much faster than your method if it converges in fewer than 20 iterations. It requires that all eight floats converge: if seven converge and one does not, the other seven have to wait for the one that does not. That's the problem with SIMD (on the GPU as well), but in general it's still faster than doing it without SIMD.
//returns a nonzero bit mask if any lane of dx is greater than xacc (not yet converged)
int get_mask(const __m256 dx, const float xacc) {
__m256 gt = _mm256_cmp_ps(dx, _mm256_set1_ps(xacc), _CMP_GT_OQ);
return _mm256_movemask_ps(gt); //one bit per float lane
}
inline __m256 SSEnthRoot_fast_xacc(const __m256 a, const int root, const int nmax, float xacc) {
// x_(n+1) = (1/root)*((root-1) * x_n + a / pow(x_n,root-1))
__m256 R = _mm256_set1_ps((float)root);
__m256 Ni = _mm256_rcp_ps(R);
//__m256 Ni = _mm256_set1_ps(1.0f/root);
__m256 Nm = _mm256_set1_ps((float)(root -1));
__m256 x = _mm256_mul_ps(a,Ni);
for(int i = 0; i <nmax ; i ++) {
__m256 tmpx = pow_avx_fast(x, root-1);
__m256 tar = _mm256_mul_ps(a,_mm256_rcp_ps(tmpx));
//tar = _mm256_fmadd_ps(Nm,x,tar);
tar = _mm256_add_ps(_mm256_mul_ps(Nm,x),tar);
tmpx = _mm256_mul_ps(Ni,tar);
__m256 dx = _mm256_sub_ps(tmpx,x);
dx = _mm256_max_ps(_mm256_sub_ps(_mm256_setzero_ps(), dx), dx); //fabs(dx)
int cnt = get_mask(dx, xacc);
if(cnt == 0) return tmpx; //all eight lanes converged: return the newest estimate
x = tmpx;
}
return x; //at least one value out of eight did not converge by nmax.
}
Here is a more general version of the pow function for AVX which also works for n <= 0.
__m256 pow_avx(__m256 x, const int n) {
if(n<0) {
return pow_avx(_mm256_rcp_ps(x), -n); //x^-n = (1/x)^n; note _mm256_rcp_ps is only approximate
}
else if(n == 0) {
return _mm256_set1_ps(1.0f);
}
else if(n == 1) {
return x;
}
else if(n%2 ==0) {
return pow_avx(_mm256_mul_ps(x, x), n/2);
}
else {
return _mm256_mul_ps(x,pow_avx(_mm256_mul_ps(x, x), (n-1)/2));
}
}
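Here is a scalar sketch of the same idea for any int exponent, using a hypothetical name, pow_int. It is iterative rather than recursive. For negative n it uses exact 1.0f / x, whereas pow_avx uses the approximate _mm256_rcp_ps. It also takes the magnitude of n as unsigned, because negating n == INT_MIN overflows. That overflow is undefined behaviour, and the recursive -n in pow_avx would hit it.

```c
/* Hypothetical scalar pow for any int n, by iterative
   (right-to-left) exponentiation by squaring. */
float pow_int(float x, int n) {
    unsigned int m = (n < 0) ? 0u - (unsigned int)n : (unsigned int)n;  /* |n| without overflow */
    float base = (n < 0) ? 1.0f / x : x;  /* x^-n = (1/x)^n */
    float result = 1.0f;                  /* x^0 = 1 */
    while (m > 0) {
        if (m & 1u) result *= base;       /* include this power of two */
        base *= base;
        m >>= 1;
    }
    return result;
}
```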
Some other suggestions
You can use a SIMD math library which can compute the nth root (see SIMD math libraries for SSE and AVX).
For Intel you can use SVML, which is expensive and closed source (Intel's OpenCL driver uses SVML, so with that you can get it for free). For AMD you can use LIBM, which is free but closed source. There are also several open-source SIMD math libraries, such as http://software-lisc.fbk.eu/avx_mathfun/ and https://bitbucket.org/eschnett/vecmathlib/wiki/Home