Question

I've been immensely frustrated with many of the implementations of python radix sort out there on the web.

They consistently use a radix of 10 and get the digits of the numbers they iterate over by dividing by a power of 10 or taking the log10 of the number. This is incredibly inefficient, as log10 is not a particularly quick operation compared to bit shifting, which is nearly 100 times faster!

A much more efficient implementation uses a radix of 256 and sorts the number byte by byte. This allows for all of the 'byte getting' to be done using the ridiculously quick bit operators. Unfortunately, it seems that absolutely nobody out there has implemented a radix sort in python that uses bit operators instead of logarithms.

So, I took matters into my own hands and came up with this beast, which runs at about half the speed of sorted on small arrays and runs nearly as quickly on larger ones (e.g. len around 10,000,000):

import itertools

def radix_sort(unsorted):
    "Fast implementation of radix sort for any size num."
    maximum, minimum = max(unsorted), min(unsorted)

    max_bits = maximum.bit_length()
    highest_byte = max_bits // 8 if max_bits % 8 == 0 else (max_bits // 8) + 1

    min_bits = minimum.bit_length()
    lowest_byte = min_bits // 8 if min_bits % 8 == 0 else (min_bits // 8) + 1

    sorted_list = unsorted
    for offset in xrange(lowest_byte, highest_byte):
        sorted_list = radix_sort_offset(sorted_list, offset)

    return sorted_list

def radix_sort_offset(unsorted, offset):
    "Helper function for radix sort, sorts each offset."
    byte_check = (0xFF << offset*8)

    buckets = [[] for _ in xrange(256)]

    for num in unsorted:
        byte_at_offset = (num & byte_check) >> offset*8
        buckets[byte_at_offset].append(num)

    return list(itertools.chain.from_iterable(buckets))

This version of radix sort works by finding which bytes it has to sort by (if you pass it only integers below 256, it sorts just one byte, and so on). It then processes each byte from the least significant upward: every number is dropped into one of 256 buckets according to that byte, and the buckets are chained back together in order. Repeat this for each byte that needs sorting and you have a sorted array in O(n * k) time for n numbers of k bytes each, which is linear in n when the byte width is fixed.

However, it's not as fast as it could be, and I'd like to make it faster before I write about it as a better radix sort than all the other radix sorts out there.

Running cProfile on this tells me that a lot of time is being spent on the append method for lists, which makes me think that this block:

    for num in unsorted:
        byte_at_offset = (num & byte_check) >> offset*8
        buckets[byte_at_offset].append(num)

in radix_sort_offset is eating a lot of time. This is also the block that, if you really look at it, does 90% of the work for the whole sort. This code looks like it could be numpy-ized, which I think would result in quite a performance boost. Unfortunately, I'm not very good with numpy's more complex features so haven't been able to figure that out. Help would be very appreciated.

I'm currently using itertools.chain.from_iterable to flatten the buckets, but if anyone has a faster suggestion I'm sure it would help as well.

Originally, I had a get_byte function that returned the nth byte of a number, but inlining the code gave me a huge speed boost so I did it.

Any other comments on the implementation or ways to squeeze out more performance are also appreciated. I want to hear anything and everything you've got.


Solution

You already realized that

for num in unsorted:
    byte_at_offset = (num & byte_check) >> offset*8
    buckets[byte_at_offset].append(num)

is where most of the time goes - good ;-)

There are two standard tricks for speeding up that kind of thing, both having to do with moving invariants out of loops:

  1. Compute "offset*8" outside the loop. Store it in a local variable. Save a multiplication per iteration.
  2. Add bucketappender = [bucket.append for bucket in buckets] outside the loop. Saves a method lookup per iteration.

Combine them, and the loop looks like:

for num in unsorted:
    bucketappender[(num & byte_check) >> ofs8](num)

Collapsing it to one statement also saves a pair of local variable store/fetch opcodes per iteration.
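For reference, here is a minimal sketch of the whole helper with both tricks applied. It assumes Python 3 (range instead of xrange); the names ofs8 and bucketappender follow the text above, and everything else is unchanged from the question's helper.

```python
from itertools import chain

def radix_sort_offset(unsorted, offset):
    """One bucket pass on the byte at `offset`, with loop invariants hoisted."""
    ofs8 = offset * 8                 # trick 1: multiply once, not per item
    byte_check = 0xFF << ofs8
    buckets = [[] for _ in range(256)]
    # trick 2: look up each bucket's bound append method once
    bucketappender = [bucket.append for bucket in buckets]
    for num in unsorted:
        bucketappender[(num & byte_check) >> ofs8](num)
    return list(chain.from_iterable(buckets))

# Low bytes are 0x01, 0x02, 0x00, 0x01, so the pass is stable within bucket 1.
print(radix_sort_offset([0x0201, 0x0102, 0x0300, 0x0001], 0))
# → [768, 513, 1, 258]
```

How much this saves depends on the interpreter version, so it is worth timing on your own data.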

But, at a higher level, the standard way to speed radix sort is to use a larger radix. What's magical about 256? Nothing, apart from that it's convenient for bit-shifting. But so are 512, 1024, 2048 ... it's a classical time/space tradeoff.
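To experiment with that tradeoff, the radix can be made a parameter. The sketch below is one possible way to do it, not the asker's code: the function name radix_sort_bits and the radix_bits parameter are made up for illustration, it assumes Python 3 and non-negative integers, and it always starts from the lowest bits.

```python
from itertools import chain

def radix_sort_bits(unsorted, radix_bits=8):
    """LSD radix sort of non-negative ints, 2**radix_bits buckets per pass."""
    if not unsorted:
        return []
    mask = (1 << radix_bits) - 1
    max_bits = max(unsorted).bit_length()
    result = unsorted
    # One pass per group of radix_bits bits; at least one pass even for all zeros.
    for shift in range(0, max(max_bits, 1), radix_bits):
        buckets = [[] for _ in range(mask + 1)]
        appenders = [bucket.append for bucket in buckets]
        for num in result:
            appenders[(num >> shift) & mask](num)
        result = list(chain.from_iterable(buckets))
    return result

# For 32-bit values: 8 bits means 4 passes, 11 bits means 3, 16 bits means 2.
print(radix_sort_bits([70000, 3, 65535, 65534, 0], radix_bits=11))
# → [0, 3, 65534, 65535, 70000]
```

A wider radix means fewer passes over the data but more bucket lists to create and chain on every pass, so the best value is likely to depend on the input size.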

PS: for very long numbers,

(num >> offset*8) & 0xff

will run faster. That's because your num & byte_check takes time proportional to log(num) - it generally has to create an integer about as big as num.
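A quick way to check this on your own machine is sketched below. The two helper names are made up for illustration, and the test value (a 100,000-bit integer probed at a high byte offset) is just an assumption about what "very long" might mean. Which form wins can depend on the offset and the interpreter version, so treat the printed numbers as a measurement, not a guarantee.

```python
import timeit

def byte_mask_first(num, offset):
    """Mask in place, then shift down (the form used in the question)."""
    return (num & (0xFF << offset * 8)) >> offset * 8

def byte_shift_first(num, offset):
    """Shift down first, then mask with a small constant."""
    return (num >> offset * 8) & 0xFF

big = (1 << 100_000) - 1   # a 100,000-bit integer, all bits set
offset = 12_000            # a byte near the top of the number

# Both forms extract the same byte.
assert byte_mask_first(big, offset) == byte_shift_first(big, offset) == 0xFF

for fn in (byte_mask_first, byte_shift_first):
    seconds = timeit.timeit(lambda: fn(big, offset), number=10_000)
    print(f"{fn.__name__}: {seconds:.4f} s for 10,000 calls")
```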

OTHER TIPS

This is an old thread, but I came across this when looking to radix sort an array of positive integers. I was trying to see if I can do any better than the already wickedly fast timsort (hats off to you again, Tim Peters) which implements python's builtin sorted and sort! Either I don't understand certain aspects of the above code, or if I do, the code as presented above has some problems IMHO.

  1. It only sorts bytes starting with the highest byte of the smallest item and ending with the highest byte of the biggest item. This may be okay in some cases of special data. But in general the approach fails to differentiate items which differ on account of the lower bits. For example:

    arr=[65535,65534]
    radix_sort(arr)
    

    produces the wrong output:

    [65535, 65534]
    
  2. The range used to loop over the helper function is not correct. What I mean is that if lowest_byte and highest_byte are the same, execution of the helper function is skipped altogether. BTW, I had to change xrange to range in two places (xrange does not exist in Python 3).

  3. With modifications to address the above two points, I got it to work. But it takes 10-20 times as long as Python's builtin sorted or sort! I know timsort is very efficient and takes advantage of already sorted runs in the data, but I was trying to see whether I could use the prior knowledge that my data is all positive integers to some advantage in my sorting. Why is the radix sort doing so badly compared to timsort? The array sizes I was using are on the order of 80K items. Is it because the timsort implementation, in addition to its algorithmic efficiency, also has other efficiencies stemming from possible use of low-level libraries? Or am I missing something entirely? The modified code I used is below:

    import itertools
    
    def radix_sort(unsorted):
        "Fast implementation of radix sort for any size num."
        maximum, minimum = max(unsorted), min(unsorted)
    
        max_bits = maximum.bit_length()
        highest_byte = max_bits // 8 if max_bits % 8 == 0 else (max_bits // 8) + 1
    
    #    min_bits = minimum.bit_length()
    #    lowest_byte = min_bits // 8 if min_bits % 8 == 0 else (min_bits // 8) + 1
    
        sorted_list = unsorted
    #    xrange changed to range, lowest_byte deleted from the arguments
        for offset in range(highest_byte):
            sorted_list = radix_sort_offset(sorted_list, offset)
    
        return sorted_list
    
    def radix_sort_offset(unsorted, offset):
        "Helper function for radix sort, sorts each offset."
        byte_check = (0xFF << offset*8)
    
    #    xrange changed to range
        buckets = [[] for _ in range(256)]
    
        for num in unsorted:
            byte_at_offset = (num & byte_check) >> offset*8
            buckets[byte_at_offset].append(num)
    
        return list(itertools.chain.from_iterable(buckets))
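To reproduce both the correctness check and the timing comparison, a small harness along these lines may help. It is only a sketch: radix_sort_lsd is a compact stand-in for the modified code above (the name is made up), it assumes Python 3 and non-negative integers, and the 80,000-item size mirrors the arrays mentioned in point 3. Keep in mind that in CPython sorted runs entirely in C, while this loop executes interpreted bytecode for every element on every pass.

```python
import random
import timeit
from itertools import chain

def radix_sort_lsd(unsorted):
    """Byte-wise LSD radix sort for non-negative ints, starting at byte 0."""
    if not unsorted:
        return []
    # Bytes needed for the largest value; always do at least one pass.
    passes = max(1, (max(unsorted).bit_length() + 7) // 8)
    result = unsorted
    for offset in range(passes):
        shift = offset * 8
        buckets = [[] for _ in range(256)]
        for num in result:
            buckets[(num >> shift) & 0xFF].append(num)
        result = list(chain.from_iterable(buckets))
    return result

# The case from point 1 that the original byte range got wrong.
assert radix_sort_lsd([65535, 65534]) == [65534, 65535]

# Randomized comparison against the builtin.
data = [random.randrange(2**32) for _ in range(80_000)]
assert radix_sort_lsd(data) == sorted(data)

# Rough timing; absolute numbers depend on machine and Python version.
for name, fn in (("radix_sort_lsd", radix_sort_lsd), ("sorted", sorted)):
    seconds = timeit.timeit(lambda: fn(data), number=3) / 3
    print(f"{name}: {seconds:.4f} s per run")
```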
    

You could simply use one of the existing C or C++ implementations, for example integer_sort from Boost.Sort or u4_sort from usort. It is surprisingly easy to call native C or C++ code from Python; see How to sort an array of integers faster than quicksort?

I totally get your frustration. Although it's been more than 2 years, numpy still does not have radix sort. I will let the NumPy developers know that they could simply grab one of the existing implementations; licensing should not be an issue.

Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow