Question

I was impressed by the speedup gained here by using numba.

Today I found a question on SO from someone who wants to speed up his code, so I thought: let's see what we can achieve with numba.

Here is the code:

from numba import autojit  # lazy-compiling decorator from early Numba releases; no longer available (numba.jit replaces it)
from time import time

LIMIT = pow(10,6)

def primes(limit):
    # Keep only odd numbers in sieve, mapping from index to number is
    # num = 2 * idx + 3
    # The square of the number corresponding to idx then corresponds to:
    # idx2 = 2*idx*idx + 6*idx + 3
    sieve = [True] * (limit // 2)
    prime_numbers = set([2])
    for j in range(len(sieve)):
        if sieve[j]:
            new_prime = 2*j + 3
            prime_numbers.add(new_prime)
            for k in range((2*j+6)*j+3, len(sieve), new_prime):
                sieve[k] = False
    return list(prime_numbers)


numba_primes = autojit(primes)



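# The first call to numba_primes also triggers JIT compilation,
# so this measurement includes the compile time.
start = time()
numba_primes(LIMIT)
end = time()
print("Numba: Time Taken : ", end - start)

start = time()
primes(LIMIT)
end = time()
print("Python: Time Taken : ", end - start)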
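The index arithmetic described in the comments of `primes` can be checked mechanically. Below is a minimal sketch (the helper names `idx_to_num` and `square_index` are hypothetical, not part of the question's code). It confirms that the square of `2*idx + 3` lands at index `(2*idx + 6)*idx + 3`, and that a stride of `new_prime` in index space skips the even multiples:

```python
# Odd-only sieve layout: index idx stands for the odd number 2*idx + 3.
def idx_to_num(idx):
    return 2 * idx + 3


def square_index(idx):
    # Index of (2*idx + 3)**2, computed the same way as the question's inner loop start.
    return (2 * idx + 6) * idx + 3


for idx in range(1000):
    num = idx_to_num(idx)
    # The square of num sits at square_index(idx) ...
    assert idx_to_num(square_index(idx)) == num * num
    # ... and one index step of size num is a step of 2*num in number space,
    # so only the odd multiples of num are visited.
    assert idx_to_num(square_index(idx) + num) == num * num + 2 * num
```

Consecutive odd multiples of a prime p differ by 2p, so a stride of p over the odd-only indices visits exactly those odd multiples.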

The result:

('Numba: Time Taken : ', 0.68790602684021)
('Python: Time Taken : ', 0.12417221069335938)

Why is this happening? With numba this code is not getting any faster; it is actually more than five times slower!
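Two factors probably contribute. The first call to a Numba-compiled function triggers compilation, so the timing above includes compile time. In addition, in Numba releases of that era, code built on Python sets and lists generally could not be compiled in nopython mode and fell back to the much slower object mode. The first factor can be isolated with a warm-up call. Below is a minimal sketch of such a timing helper. `time_after_warmup` is a hypothetical name, and the helper is written in plain Python so it works with any callable:

```python
from time import perf_counter


def time_after_warmup(func, *args, repeat=3):
    """Call func once untimed, then return the best of `repeat` timed calls.

    For a Numba-jitted function the untimed first call is the one that
    triggers compilation, so it is excluded from the measurement.
    """
    func(*args)  # warm-up call, not measured
    best = float("inf")
    for _ in range(repeat):
        start = perf_counter()
        func(*args)
        best = min(best, perf_counter() - start)
    return best
```

With the question's code, `time_after_warmup(numba_primes, LIMIT)` would report the run time of the compiled function without the compilation step.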


Solution

Here is a numba-ized version of your code (using Numba 0.13) that is optimized by using numpy arrays:

import numpy as np
import numba

# You could also just use @numba.jit or @numba.jit(nopython=True)
# here and get comparable timings.
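@numba.jit('void(uint8[:])', nopython=True)
def primes_util(sieve):
    # Sieve in place; index j stands for the odd number 2*j + 3.
    ssz = sieve.shape[0]
    for j in range(ssz):
        if sieve[j]:
            new_prime = 2*j + 3
            # Clear the odd multiples of new_prime, starting at its square.
            for k in range((2*j+6)*j+3, ssz, new_prime):
                sieve[k] = False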

def primes_numba(limit):
    sieve = np.ones(limit // 2, dtype=np.uint8)
    primes_util(sieve)

    return [2] + (np.nonzero(sieve)[0]*2 + 3).tolist()
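The example below shows how the indices that survive the sieve turn back into primes. It runs the same two-step structure as plain Python on a small limit. This is a sketch: `primes_util_py` and `primes_small` are hypothetical names, and no Numba is involved, so it is slow for large limits.

```python
import numpy as np


def primes_util_py(sieve):
    # Same loop as primes_util, run as ordinary Python (no JIT).
    ssz = sieve.shape[0]
    for j in range(ssz):
        if sieve[j]:
            new_prime = 2 * j + 3
            for k in range((2 * j + 6) * j + 3, ssz, new_prime):
                sieve[k] = 0


def primes_small(limit):
    sieve = np.ones(limit // 2, dtype=np.uint8)
    primes_util_py(sieve)
    # Surviving indices idx map back to the odd numbers 2*idx + 3.
    return [2] + (np.nonzero(sieve)[0] * 2 + 3).tolist()


print(primes_small(29))  # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```

The sieve has `limit // 2` entries, and the last one stands for `2*(limit // 2 - 1) + 3`. For an even limit such as 30 that is `limit + 1`, so `primes_small(30)` also returns 31. For `LIMIT = 10**6` this does not matter, because 1000001 = 101 * 9901 is composite.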

And here is a comparison of the timings:

In [112]: %timeit primes(LIMIT)
1 loops, best of 3: 221 ms per loop

In [113]: %timeit primes_numba(LIMIT)
100 loops, best of 3: 11 ms per loop

In [114]: a = set(primes(LIMIT))
     ...: b = set(primes_numba(LIMIT))
     ...: a == b
Out[114]: True

That's a 20x speedup, although further optimizations are probably possible. Without the jit decorator, the numba version runs in about 300 ms on my machine. The call to primes_util itself takes only about 5 ms; the rest is spent in np.nonzero and the conversion to a list.
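One such further optimization is sketched here as an assumption; it has not been measured. The idea is to keep the same odd-only layout but stop the outer loop once the square of the current number falls outside the sieve. Each prime's multiples are then cleared with a single strided NumPy slice assignment instead of an inner loop. `primes_numpy` is a hypothetical name, and this variant uses plain NumPy rather than Numba.

```python
import numpy as np


def primes_numpy(limit):
    # Same layout as primes_numba: index idx stands for the odd number 2*idx + 3.
    sieve = np.ones(limit // 2, dtype=np.uint8)
    ssz = sieve.shape[0]
    j = 0
    while True:
        start = (2 * j + 6) * j + 3      # index of (2*j + 3)**2
        if start >= ssz:                 # no square left inside the sieve: done
            break
        if sieve[j]:
            # Strided slice = every odd multiple of 2*j + 3 from its square onward.
            sieve[start::2 * j + 3] = 0
        j += 1
    return [2] + (np.nonzero(sieve)[0] * 2 + 3).tolist()
```

Stopping early is safe because every odd composite n in the sieve has an odd prime factor p with p*p <= n, and n is then an odd multiple of p at or beyond p*p, so it has already been cleared. Whether this beats the jitted loop would need to be timed.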

Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow