Question

I recently wrote some code that I thought was very inefficient, but since it handled only a few values, I accepted it. However, I'm still interested in a better algorithm for the following:

  1. Start with a list of X objects, each of which is assigned a "weight"
  2. Sum up the weights
  3. Generate a random number from 0 to the sum
  4. Iterate through the objects, subtracting each weight from the random number until the result is non-positive
  5. Remove that object from the list, and then add it to the end of the new list

Items 2, 4, and 5 each take O(n) time and are repeated once per object, so this is an O(n^2) algorithm.
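For reference, the steps above can be sketched in Python roughly as follows. This is only an illustration of the quadratic approach being asked about; the function name `naive_weighted_shuffle` and the optional `rng` parameter are my own, and positive weights are assumed.

```python
import random

def naive_weighted_shuffle(items, weights, rng=random):
    """Quadratic weighted shuffle following the numbered steps above.

    Assumes every weight is positive. `rng` is a hypothetical hook so a
    seeded random.Random instance can be passed in.
    """
    items = list(items)      # work on copies so the inputs are untouched
    weights = list(weights)
    result = []
    while items:
        total = sum(weights)             # step 2: O(n)
        r = rng.uniform(0, total)        # step 3
        i = 0
        for i, w in enumerate(weights):  # step 4: O(n) scan
            r -= w
            if r <= 0:
                break
        # step 5: removing from the middle of a list is O(n) as well
        result.append(items.pop(i))
        weights.pop(i)
    return result

if __name__ == "__main__":
    print(naive_weighted_shuffle([6, 5, 4, 3, 2, 1], [6, 5, 4, 3, 2, 1]))
```

Each pass of the outer loop does linear work, and there are n passes, which is where the O(n^2) comes from.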

Can this be improved?

In a weighted shuffle, an element with a higher weight has a greater chance of ending up near the front.

Example (I'll generate random numbers to make it real):

6 objects with weights 6,5,4,3,2,1; Sum is 21

I picked 19: 19-6-5-4-3-2 = -1, thus 2 goes in the first position, weights are now 6,5,4,3,1; Sum is 19

I picked 16: 16-6-5-4-3 = -2, thus 3 goes in the second position, weights are now 6,5,4,1; Sum is 16

I picked 3: 3-6 = -3, thus 6 goes in the third position, weights are now 5,4,1; Sum is 10

I picked 8: 8-5-4 = -1, thus 4 goes in the fourth position, weights are now 5,1; Sum is 6

I picked 5: 5-5=0, thus 5 goes in the fifth position, weights are now 1; Sum is 1

I picked 1: 1-1=0, thus 1 goes in the last position, I have no more weights, I finish


Solution

This can be implemented in O(n log(n)) using a tree.

First, create the tree, keeping in each node its own weight together with the total weight of all the descendant nodes on its left and the total weight of all the descendant nodes on its right.

To sample an item, recursively sample from the root node, using the cumulative sums to decide if you return the current node, a node from the left or a node from the right. Every time you sample a node, set its weight to zero and also update the parent nodes.

This is my implementation in Python:

import random

def weigthed_shuffle(items, weights):
    if len(items) != len(weights):
        raise ValueError("Unequal lengths")

    n = len(items)
    nodes = [None for _ in range(n)]

    def left_index(i):
        return 2 * i + 1

    def right_index(i):
        return 2 * i + 2

    def total_weight(i=0):
        if i >= n:
            return 0
        this_weight = weights[i]
        if this_weight <= 0:
            raise ValueError("Weight can't be zero or negative")
        left_weight = total_weight(left_index(i))
        right_weight = total_weight(right_index(i))
        # each node stores [own weight, left subtree total, right subtree total]
        nodes[i] = [this_weight, left_weight, right_weight]
        return this_weight + left_weight + right_weight

    def sample(i=0):
        this_w, left_w, right_w = nodes[i]
        total = this_w + left_w + right_w
        r = total * random.random()
        if r < this_w:
            nodes[i][0] = 0
            return i
        elif r < this_w + left_w:
            chosen = sample(left_index(i))
            nodes[i][1] -= weights[chosen]
            return chosen
        else:
            chosen = sample(right_index(i))
            nodes[i][2] -= weights[chosen]
            return chosen

    total_weight() # build nodes tree

    return (items[sample()] for _ in range(n))

Usage:

In [2]: items = list(range(10))
   ...: weights = list(range(10, 0, -1))
   ...:

In [3]: for _ in range(10):
   ...:     print(list(weigthed_shuffle(items, weights)))
   ...:
[5, 0, 8, 6, 7, 2, 3, 1, 4, 9]
[1, 2, 5, 7, 3, 6, 9, 0, 4, 8]
[1, 0, 2, 6, 8, 3, 7, 5, 4, 9]
[4, 6, 8, 1, 2, 0, 3, 9, 7, 5]
[3, 5, 1, 0, 4, 7, 2, 6, 8, 9]
[3, 7, 1, 2, 0, 5, 6, 4, 8, 9]
[1, 4, 8, 2, 6, 3, 0, 9, 5, 7]
[3, 5, 0, 4, 2, 6, 1, 8, 9, 7]
[6, 3, 5, 0, 1, 2, 4, 8, 7, 9]
[4, 1, 2, 0, 3, 8, 6, 5, 7, 9]

weigthed_shuffle returns a generator, so you can sample the top k items efficiently. If you want to shuffle the whole array, just iterate over the generator until exhaustion (for example, with the list function).

UPDATE:

Weighted Random Sampling (2005; Efraimidis, Spirakis) provides a very elegant algorithm for this. The implementation is super simple, and also runs in O(n log(n)):

def weigthed_shuffle(items, weights):
    order = sorted(range(len(items)), key=lambda i: -random.random() ** (1.0 / weights[i]))
    return [items[i] for i in order]
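One way to convince yourself that this key really does respect the weights is to count how often each item lands in first place over many shuffles. The sketch below reuses the same key expression; the names `es_weighted_shuffle` and `first_place_frequencies`, the `rng` parameter, and the trial count are my own choices for illustration.

```python
import random
from collections import Counter

def es_weighted_shuffle(items, weights, rng=random):
    # Same key as above: sort by u ** (1 / w), largest first.
    order = sorted(
        range(len(items)),
        key=lambda i: -rng.random() ** (1.0 / weights[i]),
    )
    return [items[i] for i in order]

def first_place_frequencies(items, weights, trials=20000, seed=0):
    """Fraction of shuffles in which each item comes out first."""
    rng = random.Random(seed)
    counts = Counter(
        es_weighted_shuffle(items, weights, rng)[0] for _ in range(trials)
    )
    return {item: counts[item] / trials for item in items}

if __name__ == "__main__":
    weights = [3, 2, 1]
    freqs = first_place_frequencies(["a", "b", "c"], weights)
    total = sum(weights)
    for item, w in zip("abc", weights):
        print(item, "observed", round(freqs[item], 3), "expected", round(w / total, 3))
```

With weights 3, 2, 1 the observed first-place frequencies should come out close to 1/2, 1/3 and 1/6, up to sampling noise.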

OTHER TIPS

EDIT: This answer doesn't interpret the weights in the way that would be expected. I.e. an item with weight 2 isn't twice as likely to be first as one with weight 1.

One way to shuffle a list is to assign a random number to each element in the list and sort by those numbers. We can extend that idea: we just have to pick weighted random numbers. For example, you could use random() * weight. Different choices will produce different distributions.

In something like Python, this should be as simple as:

items.sort(key=lambda item: random.random() * item.weight, reverse=True)

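Be careful that you don't evaluate the keys more than once, as they'll end up with different values. (Python's list.sort computes the key once per element, so the one-liner above is safe in that respect.)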
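To make the caveat in the EDIT note concrete, here is a small simulation for two items with weights 1 and 2, ordering by largest key first. If the weights were interpreted proportionally, the heavier item would come first 2/3 of the time; with random() * weight it comes first with probability 3/4, since P(2V > U) = 3/4 for independent uniform U and V. The function name, trial count and seed are my own.

```python
import random

def heavy_first_rate(trials=100_000, seed=0):
    """Estimate how often the weight-2 item beats the weight-1 item
    when each key is random() * weight and the largest key goes first."""
    rng = random.Random(seed)
    heavy_first = 0
    for _ in range(trials):
        key_light = rng.random() * 1  # item with weight 1
        key_heavy = rng.random() * 2  # item with weight 2
        if key_heavy > key_light:
            heavy_first += 1
    return heavy_first / trials

if __name__ == "__main__":
    print("observed:", heavy_first_rate())
    print("random() * weight predicts 0.75, proportional weights would give", 2 / 3)
```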

First, let's assume that the weight of a given element in the list to be sorted is constant. It is not going to change between iterations. If it does, then... well, that's a bigger problem.

For illustration, let's use a deck of cards where we want to weight the face cards to the front: weight(card) = card.rank. If we don't know the distribution of weights in advance, summing them is indeed O(n), but it only has to be done once.

These elements are stored in a sorted structure such as a modification on an indexable skip list such that all of the indexes of the levels can be accessed from a given node:

   1                               10
 o---> o---------------------------------------------------------> o    Top level
   1           3              2                    5
 o---> o---------------> o---------> o---------------------------> o    Level 3
   1        2        1        2                    5
 o---> o---------> o---> o---------> o---------------------------> o    Level 2
   1     1     1     1     1     1     1     1     1     1     1 
 o---> o---> o---> o---> o---> o---> o---> o---> o---> o---> o---> o    Bottom level

Head  1st   2nd   3rd   4th   5th   6th   7th   8th   9th   10th  NIL
      Node  Node  Node  Node  Node  Node  Node  Node  Node  Node

However in this instance, each node also 'takes up' as much room as its weight.

Now, when looking up a card in this list, one can access its position in the list in O(log n) time and remove it from the associated lists in O(1) time. OK, it might not be O(1); it might be O(log log n) time (I'd have to think about this much more). Removing the 6th node in the above example would involve updating all four levels, and those four levels are independent of how many elements there are in the list (depending on how you implement the levels).

Since the weight of an element is constant, one can simply do sum -= weight(removed) without having to traverse the structure again.

And thus, you've got a one time cost of O(n) and a lookup value of O(log n) and a remove from list cost of O(1). This becomes O(n) + n * O(log n) + n * O(1) which gives you an overall performance of O(n log n).


Let's look at this with cards, because that's what I used above.

      10
top 3 -----------------------> 4d
                                .
       3             7          .
    2 ---------> 2d ---------> 4d
                  .             .
       1      2   .  3      4   .
bot 1 --> Ad --> 2d --> 3d --> 4d

This is a really small deck with only 4 cards in it. It should be easy to see how this can be extended. With 52 cards an ideal structure would have 6 levels (log2(52) ~= 6), though if you dig into the skip lists even that could be reduced to a smaller number.

The sum of all the weights is 10. Say you draw a random number from [1 .. 10) and it's 4. You walk the skip list to find the item that is at ceiling(4). Since 4 is less than 10, you move down from the top level to the second level. 4 is greater than 3, so you advance to the 2 of diamonds. 4 is less than 3 + 7, so you move down to the bottom level, and 4 is less than 3 + 3, so you've got the 3 of diamonds.

After removing the 3 of diamonds from the structure, the structure now looks like:

       7
top 3 ----------------> 4d
                         .
       3             4   .
    2 ---------> 2d --> 4d
                  .      .
       1      2   .  4   .
bot 1 --> Ad --> 2d --> 4d

You will note that the nodes take up an amount of 'space' proportional to their weight in the structure. This allows for the weighted selection.
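The "space proportional to weight" idea can be checked on the four-card deck with plain cumulative sums. This is not a skip list: it rebuilds the running totals in O(n) on every lookup, and the helper name `pick` is my own. It only shows which card a given random number lands on, before and after the removal.

```python
from bisect import bisect_left
from itertools import accumulate

def pick(weights, r):
    """Index of the element whose weight interval contains r.

    With weights [1, 2, 3, 4] the running totals are [1, 3, 6, 10], so
    r = 1 hits index 0, r in 2..3 hits index 1, r in 4..6 hits index 2,
    and r in 7..10 hits index 3.
    """
    cumulative = list(accumulate(weights))
    return bisect_left(cumulative, r)

if __name__ == "__main__":
    cards = ["Ad", "2d", "3d", "4d"]
    weights = [1, 2, 3, 4]

    i = pick(weights, 4)
    print("picked", cards[i])            # the card covering position 4

    # Remove the picked card; the total simply drops by its weight.
    cards.pop(i)
    removed = weights.pop(i)
    print("new total", sum(weights), "after removing weight", removed)
    print("position 4 now lands on", cards[pick(weights, 4)])
```

A skip list (or a tree) gives the same answers but replaces the linear rebuild with an O(log n) walk.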

As this approximates a balanced binary tree, a lookup doesn't need to walk the bottom layer (which would be O(n)). Instead, starting from the top lets you rapidly skip down the structure to find what you are looking for.

Much of this could instead be done with some sort of balanced tree. The problem there is the rebalancing of the structure when a node is removed gets confusing since this isn't a classic tree structure and the housekeeping to remember that the 4 of diamonds is now moved from positions [6 7 8 9] to [3 4 5 6] may cost more than the benefits of the tree structure.

However, while the skip list approximates a binary tree in its ability to skip down the list in O(log n) time, it has the simplicity of working with a linked list instead.

This isn't to say that it is easy to do all this (you still need to keep tabs on all the links you need to modify when you remove an element), but it means only updating however many levels you have and their links rather than everything to the right on the proper tree structure.

Licensed under: CC-BY-SA with attribution