Question

I want to generate all "permutations" of a length-n list in which each entry ranges from 0 to n-1. For example, for n=3, all the possible lists are as follows:

[0,0,0]
[0,0,1]
[0,0,2]
[0,1,0]
[0,1,1]
[0,1,2]
[0,2,0]
...
[2,2,2]

I realize this grows very large very quickly (there are n^n such lists). I currently have the following working code, which uses recursion:

def generatePermutations(allPerms, curPerm, curIndex, n):
    # allPerms is a reference to the full set of permutations
    # curIndex is the index which is currently being changed
    if curIndex == n - 1:
        # last position: try each value and record a copy of the list
        for i in range(n):
            curPerm[curIndex] = i
            allPerms.append(list(curPerm))
    else:
        for i in range(n):
            curPerm[curIndex] = i
            # recursively generate all permutations past our curIndex
            generatePermutations(allPerms, curPerm, curIndex + 1, n)

allPermutations = []
n = 4
currentPermutation = [0] * n

generatePermutations(allPermutations, currentPermutation, 0, n)

In trying to find a non-recursive solution I've hit a wall. I think there would have to be n nested for loops, which I can't figure out how to write for arbitrary n. The only ideas I've had are adding functions containing the loops to a list and running them somehow, or, even more absurdly, generating the code programmatically and passing it to eval. My gut tells me these are more complicated than necessary. Can anyone think of a solution? Thanks!
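One way to get the effect of n nested loops without recursion is to build the lists one position at a time: start from a single empty prefix and, n times, extend every prefix with each value 0..n-1. This is only a sketch, and the name `all_lists_iterative` is hypothetical; it mirrors the rough pure-Python equivalent of `itertools.product` shown in the itertools documentation.

```python
def all_lists_iterative(n):
    """Return every length-n list with entries in 0..n-1, in lexicographic order."""
    # Start with one empty prefix; each pass of the outer loop plays the
    # role of one of the n nested for loops.
    prefixes = [[]]
    for _ in range(n):
        # Extend every existing prefix with each possible next value.
        prefixes = [prefix + [value] for prefix in prefixes for value in range(n)]
    return prefixes
```

For example, `all_lists_iterative(3)[:4]` gives `[[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0]]`. Like the recursive version, it keeps every list in memory at once.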

Solution

Simple way, with a library call:

import itertools

def lists(n):
    return itertools.product(range(n), repeat=n)

This returns an iterator over tuples rather than a list. If you need a list, call list() on the result.
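For example, assuming you want lists (as in the question) rather than the tuples that `product` yields, you can convert each item as it is produced:

```python
import itertools

def lists(n):
    # Lazily yields tuples (0, ..., 0) through (n-1, ..., n-1).
    return itertools.product(range(n), repeat=n)

# Materialize as a list of lists to match the question's format.
all_lists = [list(t) for t in lists(3)]
print(len(all_lists))               # 27
print(all_lists[0], all_lists[-1])  # [0, 0, 0] [2, 2, 2]
```

If you only need to loop over the results, iterating over `lists(n)` directly avoids holding all n^n items in memory.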

If you want to do this without foisting the job onto itertools, you can count in base n, incrementing the last digit and carrying whenever you hit n:

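def lists(n):
    l = [0]*n
    while True:
        yield l[:]  # yield a copy so later increments don't change it
        # Scan from the last digit: reset every digit that is already n-1
        # (the carry) and stop at the first digit that can be incremented.
        for i in reversed(range(n)):
            if l[i] != n-1:
                break
            l[i] = 0
        else:
            # All digits were n-1; we're done
            return
        l[i] += 1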
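The same base-n counting can also be done without a carry loop: the k-th list (for 0 <= k < n**n) is just k written in base n with n digits. A minimal sketch, where `nth_list` and `lists_by_index` are hypothetical helper names:

```python
def nth_list(k, n):
    """Return k written as n base-n digits, most significant digit first."""
    digits = [0] * n
    for pos in reversed(range(n)):
        # Peel off the lowest remaining base-n digit.
        k, digits[pos] = divmod(k, n)
    return digits

def lists_by_index(n):
    # Each integer in 0 .. n**n - 1 maps to exactly one list, in lexicographic order.
    for k in range(n ** n):
        yield nth_list(k, n)
```

A side benefit is random access: `nth_list(5, 3)` returns `[0, 1, 2]` without generating the five lists before it.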

OTHER TIPS

You can use itertools.permutations() to handle the whole problem for you, in one go:

from itertools import permutations

allPermutations = list(permutations(range(4)))

The itertools documentation includes pure-Python equivalents for that function, for example a version built on itertools.product():

from itertools import product

def permutations(iterable, r=None):
    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    for indices in product(range(n), repeat=r):
        if len(set(indices)) == r:
            yield tuple(pool[i] for i in indices)
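A quick check that this product-based version yields the same tuples, in the same order, as the real function. This is a sketch; the rename to `permutations_via_product` is only there to avoid shadowing `itertools.permutations`.

```python
from itertools import permutations, product

def permutations_via_product(iterable, r=None):
    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    # Walk all n**r index tuples and keep only those with no repeated index.
    for indices in product(range(n), repeat=r):
        if len(set(indices)) == r:
            yield tuple(pool[i] for i in indices)

print(list(permutations_via_product(range(3))) == list(permutations(range(3))))  # True
```

It also shows why this version is only illustrative: it generates all n**r index tuples and throws away the ones with repeats.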

However, your expected output is simply the product of range(3) with itself 3 times (repeat=3):

>>> from itertools import product
>>> for p in product(range(3), repeat=3):
...     print(p)
... 
(0, 0, 0)
(0, 0, 1)
(0, 0, 2)
(0, 1, 0)
(0, 1, 1)
(0, 1, 2)
(0, 2, 0)
(0, 2, 1)
(0, 2, 2)
(1, 0, 0)
(1, 0, 1)
(1, 0, 2)
(1, 1, 0)
(1, 1, 1)
(1, 1, 2)
(1, 2, 0)
(1, 2, 1)
(1, 2, 2)
(2, 0, 0)
(2, 0, 1)
(2, 0, 2)
(2, 1, 0)
(2, 1, 1)
(2, 1, 2)
(2, 2, 0)
(2, 2, 1)
(2, 2, 2)

Permutations form a much shorter sequence (3! = 6 items instead of 3^3 = 27):

>>> from itertools import permutations
>>> for p in permutations(range(3)):
...     print(p)
... 
(0, 1, 2)
(0, 2, 1)
(1, 0, 2)
(1, 2, 0)
(2, 0, 1)
(2, 1, 0)
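The gap grows quickly: there are n**n product tuples but only n! permutations. A small sketch that counts both directly (the `counts` dictionary is just for illustration):

```python
from itertools import permutations, product
from math import factorial

counts = {}
for n in range(1, 6):
    n_product = sum(1 for _ in product(range(n), repeat=n))
    n_perms = sum(1 for _ in permutations(range(n)))
    # The counts match the closed forms n**n and n!.
    assert n_product == n ** n and n_perms == factorial(n)
    counts[n] = (n_product, n_perms)
    print(n, n_product, n_perms)
```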
Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow