What is the fastest way to find integer square root using bit shifts?
12-06-2021
Question
I was looking for the fastest method to calculate the integer square root of an integer. I came across this solution on Wikipedia, which finds the square root of a number if it is a perfect square, or the square root of the nearest lower perfect square if it is not:
short isqrt(short num) {
    short res = 0;
    short bit = 1 << 14; // The second-to-top bit is set: 1L<<30 for long

    // "bit" starts at the highest power of four <= the argument.
    while (bit > num)
        bit >>= 2;

    while (bit != 0) {
        if (num >= res + bit) {
            num -= res + bit;
            res = (res >> 1) + bit;
        }
        else
            res >>= 1;
        bit >>= 2;
    }
    return res;
}
I tried a lot of test runs to trace the algorithm, but I still do not understand the part inside while (bit != 0). Can anybody explain this part to me?
Solution
I traced out a few small examples too, and I think I got it. As best as I understand it, the algorithm is building up the answer one binary digit at a time, from highest bit to lowest bit.
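To make the "one binary digit at a time" idea concrete before getting into the details, here is a minimal sketch of a slower version that does the same thing in the most direct way. The name isqrt_naive is made up for this answer (it is not from Wikipedia or the question), and it assumes a non-negative input and a 16-bit short, as in the question. It squares every candidate, which is exactly the work the posted routine manages to avoid.

```c
#include <assert.h>

/* Hypothetical helper for illustration only: builds the root from the
   highest bit down, testing each candidate by squaring it directly. */
short isqrt_naive(short num) {
    assert(num >= 0);                      /* assumption: non-negative input */
    short y = 0;                           /* bits of the answer fixed so far */
    for (int x = 7; x >= 0; x--) {         /* sqrt(32767) < 2^8, so 8 bits suffice */
        long candidate = (long)y + (1L << x);   /* y with bit x switched on */
        if (candidate * candidate <= (long)num)
            y = (short)candidate;          /* the square still fits: keep the bit */
    }
    return y;
}
```

For example, isqrt_naive(30) keeps the 2^2 bit (16 <= 30), rejects the 2^1 bit (36 > 30), and keeps the 2^0 bit (25 <= 30), giving 5.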
Let "num_init" be the value of num at the beginning of the function. Suppose that at the start of some iteration we have bit = 4^x and num equal to some value "num_curr" (a quick glance shows that bit is always a power of 4 until it reaches 0). Then res is of the form y*2^(x+1), where y^2 + num_curr = num_init, and y is at most the actual answer, falling short of it by less than 2^(x+1).
This invariant on the values of num, res, and bit is going to be key. The way this is done in the code is that
while (bit != 0) {
    ...
}
is moving our imaginary pointer left to right, and at each step we determine whether this bit is 0 or 1.
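If you want to convince yourself of the invariant rather than take my word for it, one option is to run the same loop with the invariant written out as assertions. This is only a sketch: isqrt_checked, x, y, and num_init are names I added for bookkeeping, and it assumes a non-negative input and a 16-bit short.

```c
#include <assert.h>

/* Same loop as in the question, plus bookkeeping (x, y, num_init) that
   exists only so the invariant can be asserted on every pass. */
short isqrt_checked(short num) {
    assert(num >= 0);                  /* assumption: non-negative input */
    const long num_init = num;
    short res = 0;
    short bit = 1 << 14;
    int x = 7;                         /* tracks bit == 4^x */
    long y = 0;                        /* root bits decided so far */

    while (bit > num) {
        bit >>= 2;
        x--;
    }

    while (bit != 0) {
        long upper = y + (1L << (x + 1));
        assert(bit == (1L << (2 * x)));      /* bit = 4^x */
        assert(res == (y << (x + 1)));       /* res = y * 2^(x+1) */
        assert(y * y + num == num_init);     /* num has dropped by y^2 */
        assert(upper * upper > num_init);    /* root < y + 2^(x+1) */

        if (num >= res + bit) {
            num -= res + bit;
            res = (res >> 1) + bit;
            y += 1L << x;                    /* bit x of the root is 1 */
        } else {
            res >>= 1;                       /* bit x of the root is 0 */
        }
        bit >>= 2;
        x--;
    }
    assert(res == y);                  /* at the end, res is the root itself */
    return res;
}
```

Calling isqrt_checked for every value from 0 to 32767 exercises all four assertions on every pass; the final assertion shows why returning res gives the answer.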
Going to the first if statement, suppose our imaginary "built-up" integer is equal to y, and we're looking at the 2^x bit. Then, the bit is 1 iff the original value of num is at least (y + 2^x)^2 = y^2 + y*2^(x+1) + 4^x. In other words, the bit is one if the value of num at that point is at least y*2^(x+1) + 4^x (since we have the invariant that the value of num has dropped by y^2). Conveniently enough, res = y*2^(x+1) and bit = 4^x. We then get the point behind
if (num >= res + bit) {
    num -= res + bit;
    res = (res >> 1) + bit;
}
else
    res >>= 1;
which adds a 1 bit at our imaginary spot if necessary, then updates res and num to keep the invariant. Lastly
bit >>= 2;
updates bit and moves everything along one step.
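It may help to watch the three variables on a small input. The following is a sketch of the loop from the question with a printf at the top of each pass; isqrt_trace is a name I made up, and it assumes a non-negative input.

```c
#include <assert.h>
#include <stdio.h>

/* The loop from the question, printing bit, res and num at the top of
   each pass. isqrt_trace is a hypothetical name used only here. */
short isqrt_trace(short num) {
    assert(num >= 0);                  /* assumption: non-negative input */
    short res = 0;
    short bit = 1 << 14;

    while (bit > num)
        bit >>= 2;

    while (bit != 0) {
        printf("bit=%d res=%d num=%d\n", bit, res, num);
        if (num >= res + bit) {
            num -= res + bit;          /* num now excludes (y + 2^x)^2 */
            res = (res >> 1) + bit;
        } else {
            res >>= 1;
        }
        bit >>= 2;
    }
    printf("result=%d remainder=%d\n", res, num);
    return res;
}
```

For isqrt_trace(30) this prints:

bit=16 res=0 num=30
bit=4 res=16 num=14
bit=1 res=8 num=14
result=5 remainder=5

Reading it against the invariant: on the second pass y = 4 and x = 1, so res = 4*2^2 = 16 and num = 30 - 4^2 = 14. The test 14 >= 16 + 4 fails, so the 2^1 bit is 0. On the third pass 14 >= 8 + 1 succeeds, so the 2^0 bit is 1, and 5^2 + 5 = 30.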