How can I fix and optimize this very simple piece of "Game of Life" code by taking advantage of NumPy's functionality?

StackOverflow https://stackoverflow.com/questions/22802708

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import animation
from random import randint

arraySize = 50
Z = np.array([[randint(0, 1) for x in range(arraySize)] for y in range(arraySize)])


def computeNeighbours(Z):
    rows, cols = len(Z), len(Z[0])
    N = np.zeros(np.shape(Z))

    for x in range(rows):
        for y in range(cols):
            Q = [q for q in [x-1, x, x+1] if ((q >= 0) and (q < rows))]  # x indexes rows
            R = [r for r in [y-1, y, y+1] if ((r >= 0) and (r < cols))]  # y indexes columns
            S = [Z[q][r] for q in Q for r in R if (q, r) != (x, y)]
            N[x][y] = sum(S)

    return N

def iterate(Z):
    rows, cols = len(Z), len(Z[0])
    N = computeNeighbours(Z)

    for x in range(rows):
        for y in range(cols):
            if Z[x][y] == 1:
                if (N[x][y] < 2) or (N[x][y] > 3):
                    Z[x][y] = 0
            else:
                if (N[x][y] == 3):
                    Z[x][y] = 1

    return Z

fig = plt.figure()

Zs = [Z]
ims = []

for i in range(0, 100):
    im = plt.imshow(Zs[len(Zs)-1], interpolation = 'nearest', cmap='binary')
    ims.append([im])
    Zs.append(iterate(Z))

ani = animation.ArtistAnimation(fig, ims, interval=250, blit=True)
plt.show()

At first, I wrote a simple implementation of "Game of Life" using nothing but standard Python tools. I plotted it, it ran correctly, and it animated correctly.

I tried converting the arrays into NumPy arrays next, which is where the code is as it stands right now. However, the animation no longer seems to work, and I haven't been able to figure out why. (Update: this bug has been fixed!)

Next, I am interested in taking advantage of NumPy in order to optimize my code. What I have done so far is convert the Python lists I was using into NumPy arrays. While I am sure there is some gain in performance, it is not noticeable.

I would like to know what sequence of optimizations one would perform for this sort of application, so that I can get a handle on how to use NumPy's power for my current project, which is a (perhaps) three-dimensional cellular automaton with many rules.


The following changes to the code fixed the animation error:

1) Change iterate to create a deep copy of Z, and then make changes to that deep copy. The new iterate:

def iterate(Z):
    Zprime = Z.copy()
    rows, cols = len(Zprime), len(Zprime[0])
    N = computeNeighbours(Zprime)

    for x in range(rows):
        for y in range(cols):
            if Zprime[x][y] == 1:
                if (N[x][y] < 2) or (N[x][y] > 3):
                    Zprime[x][y] = 0
            else:
                if (N[x][y] == 3):
                    Zprime[x][y] = 1

    return Zprime

2) As a consequence of 1, change this piece of code:

for i in range(0, 100):
    im = plt.imshow(Zs[len(Zs)-1], interpolation = 'nearest', cmap='binary')
    ims.append([im])
    Zs.append(iterate(Z))

to:

for i in range(0, 100):
    im = plt.imshow(Zs[len(Zs)-1], interpolation = 'nearest', cmap='binary')
    ims.append([im])
    Zs.append(iterate(Zs[len(Zs)-1]))
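
To illustrate why the copy matters, here is a minimal sketch. The functions step_in_place and step_copy are hypothetical stand-ins for the original and the fixed iterate, and they flip a single cell instead of applying the real rules. When the step function mutates and returns its argument, every entry of the history list is the same array object, so every frame shows the same (latest) state.

```python
import numpy as np

def step_in_place(Z):
    # Hypothetical stand-in for the original iterate: mutates and returns its argument.
    Z[0, 0] = 1 - Z[0, 0]
    return Z

def step_copy(Z):
    # Hypothetical stand-in for the fixed iterate: works on a copy.
    Zprime = Z.copy()
    Zprime[0, 0] = 1 - Zprime[0, 0]
    return Zprime

# History built with the in-place version: every entry aliases one array.
Z = np.zeros((2, 2), dtype=int)
aliased = [Z]
for _ in range(3):
    aliased.append(step_in_place(Z))

# History built with the copying version: every entry is a separate array.
distinct = [np.zeros((2, 2), dtype=int)]
for _ in range(3):
    distinct.append(step_copy(distinct[-1]))

all_same_object = all(frame is aliased[0] for frame in aliased)
distinct_values = [int(frame[0, 0]) for frame in distinct]
print(all_same_object, distinct_values)
```

With the copying version each stored frame keeps its own state, which is what the animation needs.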

Solution

The following:

for x in range(rows):
    for y in range(cols):
        if Z[x][y] == 1:
            if (N[x][y] < 2) or (N[x][y] > 3):
                Z[x][y] = 0
        else:
            if (N[x][y] == 3):
                Z[x][y] = 1

could be replaced by:

set_zero_idxs = (Z==1) & ((N<2) | (N>3))
set_one_idxs = (Z!=1) & (N==3)
Z[set_zero_idxs] = 0
Z[set_one_idxs] = 1

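This performs more elementary operations than your loop, because each mask is evaluated over the whole array, but those operations run in compiled NumPy code rather than in the Python interpreter, so I would expect it to be faster.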
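
As a sanity check, the two versions can be wrapped in functions and compared on random input. This is only a sketch: the names apply_rules_loop and apply_rules_masked are made up for illustration, and N is filled with arbitrary counts purely to exercise every branch of the rules.

```python
import numpy as np

def apply_rules_loop(Z, N):
    # Cell-by-cell version, working on a copy so the input is left untouched.
    Z = Z.copy()
    rows, cols = Z.shape
    for x in range(rows):
        for y in range(cols):
            if Z[x, y] == 1:
                if N[x, y] < 2 or N[x, y] > 3:
                    Z[x, y] = 0
            elif N[x, y] == 3:
                Z[x, y] = 1
    return Z

def apply_rules_masked(Z, N):
    # Boolean-mask version: both masks are built before anything is modified.
    Z = Z.copy()
    set_zero_idxs = (Z == 1) & ((N < 2) | (N > 3))
    set_one_idxs = (Z != 1) & (N == 3)
    Z[set_zero_idxs] = 0
    Z[set_one_idxs] = 1
    return Z

rng = np.random.default_rng(0)
Z = rng.integers(0, 2, size=(20, 20))
N = rng.integers(0, 9, size=(20, 20))  # arbitrary counts in 0..8, not real neighbour sums
same = np.array_equal(apply_rules_loop(Z, N), apply_rules_masked(Z, N))
print(same)
```

Both masks are computed from the unmodified Z, so the order of the two assignments does not matter.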

EDIT:

I have just benchmarked the two solutions. Somewhat unsurprisingly, the NumPy version is roughly 180 times faster:

In [49]: %timeit no_loop(z,n)
1000 loops, best of 3: 177 us per loop

In [50]: %timeit loop(z,n)
10 loops, best of 3: 31.2 ms per loop
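
The definitions of loop and no_loop are not shown above, so the following is a guess at what they might look like, timed with the standard timeit module instead of IPython's %timeit. The array size and random inputs are assumptions, and the absolute numbers will differ from machine to machine.

```python
import timeit
import numpy as np

def loop(z, n):
    # Assumed definition: the original nested-loop rule update.
    rows, cols = z.shape
    for x in range(rows):
        for y in range(cols):
            if z[x, y] == 1:
                if n[x, y] < 2 or n[x, y] > 3:
                    z[x, y] = 0
            elif n[x, y] == 3:
                z[x, y] = 1
    return z

def no_loop(z, n):
    # Assumed definition: the boolean-mask rule update.
    set_zero_idxs = (z == 1) & ((n < 2) | (n > 3))
    set_one_idxs = (z != 1) & (n == 3)
    z[set_zero_idxs] = 0
    z[set_one_idxs] = 1
    return z

rng = np.random.default_rng(0)
z = rng.integers(0, 2, size=(50, 50))
n = rng.integers(0, 9, size=(50, 50))  # arbitrary counts, just for timing

# Time each version on a fresh copy so earlier runs do not change the input.
runs = 5
t_loop = min(timeit.repeat(lambda: loop(z.copy(), n), number=runs, repeat=3)) / runs
t_no_loop = min(timeit.repeat(lambda: no_loop(z.copy(), n), number=runs, repeat=3)) / runs
print(f"loop: {t_loop * 1e3:.3f} ms per call, no_loop: {t_no_loop * 1e6:.1f} us per call")
```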

EDIT2:

I think that this loop:

for x in range(rows):
    for y in range(cols):
        Q = [q for q in [x-1, x, x+1] if ((q >= 0) and (q < rows))]  # x indexes rows
        R = [r for r in [y-1, y, y+1] if ((r >= 0) and (r < cols))]  # y indexes columns
        S = [Z[q][r] for q in Q for r in R if (q, r) != (x, y)]
        N[x][y] = sum(S)

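can be replaced by a sum of shifted copies of Z. All eight shifts are needed (four orthogonal and four diagonal), because the loop counts all eight neighbours:

N = sum(np.roll(np.roll(Z, dx, axis=0), dy, axis=1)
        for dx in (-1, 0, 1) for dy in (-1, 0, 1) if (dx, dy) != (0, 0))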

This implicitly assumes that the array wraps around (periodic boundaries), so that x[-1] is adjacent to x[0], whereas the original loop treats cells beyond the edge as dead. If this is a problem, you could add a buffer of zeros around your array with:

shape = Z.shape
new_shape = (shape[0]+2, shape[1]+2)
b_z = np.zeros(new_shape, dtype=Z.dtype)
b_z[1:-1,1:-1] = Z  # the zero border absorbs the wrap-around of np.roll
b_n = sum(np.roll(np.roll(b_z, dx, axis=0), dy, axis=1)
          for dx in (-1, 0, 1) for dy in (-1, 0, 1) if (dx, dy) != (0, 0))
N = b_n[1:-1,1:-1]
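
A variant of the zero-buffer idea, sketched here as one possible alternative, skips np.roll and instead adds up eight shifted slices of an array padded with np.pad. The function names count_neighbours_padded and count_neighbours_loop are illustrative only. The second function is a plain reference implementation used to check the first.

```python
import numpy as np

def count_neighbours_padded(Z):
    # Pad with a one-cell border of zeros so cells beyond the edge count as dead.
    rows, cols = Z.shape
    P = np.pad(Z, 1, mode="constant", constant_values=0)
    N = np.zeros_like(Z)
    for dx in (-1, 0, 1):
        for dy in (-1, 0, 1):
            if (dx, dy) != (0, 0):
                # Slice of the padded array shifted by (dx, dy), same shape as Z.
                N += P[1 + dx : 1 + dx + rows, 1 + dy : 1 + dy + cols]
    return N

def count_neighbours_loop(Z):
    # Reference: sum the clipped 3x3 window around each cell, minus the cell itself.
    rows, cols = Z.shape
    N = np.zeros_like(Z)
    for x in range(rows):
        for y in range(cols):
            x0, x1 = max(x - 1, 0), min(x + 2, rows)
            y0, y1 = max(y - 1, 0), min(y + 2, cols)
            N[x, y] = Z[x0:x1, y0:y1].sum() - Z[x, y]
    return N

rng = np.random.default_rng(1)
Z = rng.integers(0, 2, size=(7, 11))  # deliberately non-square
matches = np.array_equal(count_neighbours_padded(Z), count_neighbours_loop(Z))
print(matches)
```

Testing on a non-square array helps catch any mix-up between the row and column bounds.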

And for a benchmark:

In [4]: %timeit computeNeighbours(z)
10 loops, best of 3: 140 ms per loop 

In [5]: %timeit noloop_computeNeighbours(z)
10000 loops, best of 3: 133 us per loop

In [6]: %timeit noloop_with_buffer_computeNeighbours(z)
10000 loops, best of 3: 170 us per loop

So just a small improvement, by a factor of about 1,050. Hooray for NumPy!
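
Since the question mentions a possibly three-dimensional automaton, here is a rough sketch of how the shifted-copy idea might generalize to any number of dimensions with wrap-around edges. The names count_neighbours_nd and step are hypothetical. The step function applies the ordinary two-dimensional Life rules only as an example, and a 3D automaton would substitute its own rules.

```python
import itertools
import numpy as np

def count_neighbours_nd(Z):
    # Sum all 3**Z.ndim - 1 shifted copies of Z (periodic boundaries).
    N = np.zeros_like(Z)
    axes = tuple(range(Z.ndim))
    for offset in itertools.product((-1, 0, 1), repeat=Z.ndim):
        if any(offset):  # skip the all-zero offset, i.e. the cell itself
            N += np.roll(Z, offset, axis=axes)
    return N

def step(Z):
    # Standard Life rules: survive with 2 or 3 neighbours, be born with exactly 3.
    N = count_neighbours_nd(Z)
    survive = (Z == 1) & ((N == 2) | (N == 3))
    born = (Z == 0) & (N == 3)
    return (survive | born).astype(Z.dtype)

# A blinker oscillates with period 2, which makes a convenient check.
blinker = np.zeros((5, 5), dtype=int)
blinker[2, 1:4] = 1
after_one = step(blinker)
after_two = step(after_one)
print(np.array_equal(after_two, blinker))

# In three dimensions every cell has 26 neighbours.
cube_counts = count_neighbours_nd(np.ones((4, 4, 4), dtype=int))
```

Because step returns a new array rather than modifying its input, it can be appended to a list of frames directly.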

Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow