I've been immensely frustrated with many of the Python radix sort implementations out there on the web.
They consistently use a radix of 10 and extract the digits of each number by dividing by a power of 10 or by taking the log10 of the number. This is incredibly inefficient, because log10 is not a particularly quick operation compared to bit shifting, which is nearly 100 times faster!
A much more efficient implementation uses a radix of 256 and sorts the numbers byte by byte. This allows all of the 'byte getting' to be done with the ridiculously quick bitwise operators. Unfortunately, it seems that absolutely nobody out there has implemented a radix sort in Python that uses bitwise operators instead of logarithms.
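Here is a small sketch that makes the contrast concrete. The helper names `decimal_digit`, `byte_at`, `decimal_passes` and `byte_passes` are hypothetical and exist only for illustration. The sketch shows how each radix extracts one key per pass and how many passes each one needs. With radix 256, a byte comes out with one shift and one mask. Because each pass consumes 8 bits instead of one decimal digit, the pass count also drops. No timings are claimed here.

```python
def decimal_digit(num, i):
    """Digit i (0 = ones place) of a non-negative int, radix-10 style."""
    return (num // 10 ** i) % 10

def byte_at(num, i):
    """Byte i (0 = least significant) of a non-negative int, radix-256 style.
    Same value as (num // 256 ** i) % 256, but computed with a shift and a mask."""
    return (num >> (8 * i)) & 0xFF

def decimal_passes(maximum):
    """Passes an LSD radix-10 sort needs: the number of decimal digits."""
    passes = 1
    while maximum >= 10:
        maximum //= 10
        passes += 1
    return passes

def byte_passes(maximum):
    """Passes an LSD radix-256 sort needs: ceil(bit_length / 8)."""
    return (maximum.bit_length() + 7) // 8

if __name__ == "__main__":
    n = 10 ** 9
    print(decimal_passes(n), byte_passes(n))  # 10 4
```

For values up to a billion, that is 10 bucket passes in base 10 versus 4 in base 256.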
So, I took matters into my own hands and came up with this beast. It runs at about half the speed of `sorted` on small arrays and nearly as quickly on larger ones (e.g. `len` around 10,000,000):
```python
import itertools

def radix_sort(unsorted):
    "Fast implementation of radix sort for any size num (non-negative ints only)."
    if not unsorted:
        return []
    maximum = max(unsorted)
    # Number of bytes needed for the largest value, i.e. ceil(bits / 8).
    highest_byte = (maximum.bit_length() + 7) // 8
    sorted_list = unsorted
    # LSD radix sort has to start at byte 0 even if every value is large:
    # each lower byte must be sorted before the bytes above it.
    for offset in range(highest_byte):
        sorted_list = radix_sort_offset(sorted_list, offset)
    return sorted_list

def radix_sort_offset(unsorted, offset):
    "Helper function for radix sort, sorts each offset."
    byte_check = (0xFF << offset*8)  # mask selecting byte number `offset`
    buckets = [[] for _ in range(256)]
    # Stable bucketing: numbers keep their relative order within a bucket.
    for num in unsorted:
        byte_at_offset = (num & byte_check) >> offset*8
        buckets[byte_at_offset].append(num)
    # Concatenate buckets 0..255 back into a single list.
    return list(itertools.chain.from_iterable(buckets))
```
This version of radix sort first works out which bytes it has to sort by (if you pass it only integers below 256, it'll sort just one byte, etc.). It then sorts on each byte from the LSB up: it drops the numbers into 256 buckets in order and chains the buckets back together. Repeat this for each byte that needs to be sorted and you have your nice sorted array in O(n·k) time, where k is the number of bytes in the largest value. For a fixed integer width, that is linear in n.
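Below is a small trace of the LSD passes, written as a sketch. `lsd_trace` is a hypothetical helper that repeats the bucket-and-chain step described above and records the list after each byte. In the chosen input, two values share a high byte. That shows why every pass has to be stable, meaning numbers must keep their order within a bucket.

```python
import itertools

def lsd_trace(nums):
    """Run byte-wise LSD radix sort on non-negative ints and return the list
    as it looks after each pass (hypothetical helper for illustration)."""
    passes = (max(nums).bit_length() + 7) // 8
    snapshots = []
    for offset in range(passes):
        shift = offset * 8
        buckets = [[] for _ in range(256)]
        for num in nums:
            # Appending in input order keeps each bucket stable.
            buckets[(num >> shift) & 0xFF].append(num)
        nums = list(itertools.chain.from_iterable(buckets))
        snapshots.append(nums)
    return snapshots

if __name__ == "__main__":
    for step in lsd_trace([0x0102, 0x0201, 0x0101, 0x0002]):
        print([hex(n) for n in step])
    # ['0x201', '0x101', '0x102', '0x2']
    # ['0x2', '0x101', '0x102', '0x201']
```

After the low-byte pass, 0x101 comes before 0x102. The high-byte pass puts both into bucket 1, and because appending preserves order, they stay correctly ordered.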
However, it's not as fast as it could be, and I'd like to make it faster before I write it up as a better radix sort than all the other radix sorts out there.
Running `cProfile` on this tells me that a lot of time is being spent in the `append` method for lists, which makes me think that this block:
```python
for num in unsorted:
    byte_at_offset = (num & byte_check) >> offset*8
    buckets[byte_at_offset].append(num)
```
in `radix_sort_offset` is eating a lot of time. If you really look at it, this is also the block that does 90% of the work for the whole sort. It looks like it could be `numpy`-ized, which I think would give quite a performance boost. Unfortunately, I'm not very good with `numpy`'s more complex features, so I haven't been able to figure that out. Help would be very much appreciated.
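One commonly used way to avoid 256 growing lists is a stable counting-sort pass. It counts how many numbers have each byte value, turns those counts into starting positions, and then writes every number straight into a preallocated output list. The sketch below applies that idea to one byte. `radix_sort_offset_counting` is a hypothetical name, and this is a swapped-in technique, not the bucket loop from the question. I haven't benchmarked it. It runs three Python-level loops instead of one, so in CPython it may well not beat `append`; treat it as something to measure.

```python
def radix_sort_offset_counting(unsorted, offset):
    """Stable counting-sort pass on byte `offset` (hypothetical alternative
    to radix_sort_offset; writes into a preallocated list, no append)."""
    shift = offset * 8
    # Extract each number's byte once; both loops below reuse it.
    keys = [(num >> shift) & 0xFF for num in unsorted]
    # Histogram: how many numbers have each byte value 0..255.
    counts = [0] * 256
    for key in keys:
        counts[key] += 1
    # Exclusive prefix sums: starts[b] is where bucket b begins in the output.
    starts = [0] * 256
    total = 0
    for b in range(256):
        starts[b] = total
        total += counts[b]
    # Scatter in input order so equal bytes keep their relative order (stable).
    output = [None] * len(unsorted)
    for num, key in zip(unsorted, keys):
        output[starts[key]] = num
        starts[key] += 1
    return output
```

Because the pass is stable, it can be dropped into the same LSD loop (byte 0 first, then byte 1, and so on) in place of the bucket version.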
I'm currently using `itertools.chain.from_iterable` to flatten the buckets, but if anyone has a faster suggestion, I'm sure it would help as well.
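To compare flattening strategies, here is a small harness. The three function names are hypothetical. `list(itertools.chain.from_iterable(...))`, repeated `list.extend`, and a nested list comprehension all produce the same list. Which one is fastest depends on the Python version and the bucket sizes, so it is worth measuring rather than guessing. One option to avoid is `sum(buckets, [])`: it builds a new list on every addition, which makes it quadratic.

```python
import itertools
import timeit

def flatten_chain(buckets):
    # The approach the question uses now.
    return list(itertools.chain.from_iterable(buckets))

def flatten_extend(buckets):
    # Grow a single list in place, one bucket at a time.
    out = []
    for bucket in buckets:
        out.extend(bucket)
    return out

def flatten_comprehension(buckets):
    # Nested list comprehension over buckets, then items.
    return [num for bucket in buckets for num in bucket]

if __name__ == "__main__":
    # 256 buckets of 1,000 ints each; adjust to match your real bucket sizes.
    buckets = [list(range(i, i + 1000)) for i in range(256)]
    for fn in (flatten_chain, flatten_extend, flatten_comprehension):
        seconds = timeit.timeit(lambda: fn(buckets), number=20)
        print(f"{fn.__name__}: {seconds:.3f}s for 20 runs")
```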
Originally, I had a `get_byte` function that returned the nth byte of a number, but inlining the code gave me a huge speed boost, so I did that.
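The original `get_byte` isn't shown, so the version below is a hypothetical reconstruction that only matches the description: it returns the nth byte of a number. The sketch contrasts a per-element call with the inlined loop. It also adds one more CPython micro-optimization sometimes used for the same reason: binding each bucket's `append` method once, so the loop skips the `.append` attribute lookup. All three functions (hypothetical names) produce the same buckets. Any speed difference comes from per-element interpreter overhead, so measure it on your own data.

```python
def get_byte(num, n):
    """Hypothetical reconstruction: byte n (0 = least significant) of num."""
    return (num >> (n * 8)) & 0xFF

def buckets_with_call(nums, offset):
    # One extra Python function call per element.
    buckets = [[] for _ in range(256)]
    for num in nums:
        buckets[get_byte(num, offset)].append(num)
    return buckets

def buckets_inlined(nums, offset):
    # Byte extraction inlined; the shift amount is computed once, not per element.
    shift = offset * 8
    buckets = [[] for _ in range(256)]
    for num in nums:
        buckets[(num >> shift) & 0xFF].append(num)
    return buckets

def buckets_bound_append(nums, offset):
    # Also pre-binds each bucket's append method, so the loop does an index
    # and a call but no `.append` attribute lookup per element.
    shift = offset * 8
    buckets = [[] for _ in range(256)]
    appenders = [bucket.append for bucket in buckets]
    for num in nums:
        appenders[(num >> shift) & 0xFF](num)
    return buckets
```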
Any other comments on the implementation or ways to squeeze out more performance are also appreciated. I want to hear anything and everything you've got.