Question

I am updating a 3D scatter plot on every iteration of a loop. When the plot is redrawn, the gridlines "go through" or "cover" the points, which makes my data harder to visualize. If I build a single 3D plot (no updating in a loop), this does not happen. The code below demonstrates the simplest case:

import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import time

X = np.random.rand(100, 3)*10
Y = np.random.rand(100, 3)*5

plt.ion()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2])
plt.draw()

for i in range(0, 20):
    time.sleep(3)   #make changes more apparent/easy to see

    Y = np.random.rand(100, 3)*5
    ax.cla()    
    ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2])
    plt.draw()

Has anyone else encountered this problem?


Solution

It looks like MaxNoe is right in the sense that the problem lies in the ax.cla() or plt.cla() call. In fact, it appears to be a known issue.

That leaves a problem: clearing the axes does not work properly for 3D plots, and for 3D scatter plots there is no clean public way to change the coordinates of the data points (à la sc.set_data(new_values)), as discussed in this mailing list thread (I didn't find anything more recent).

In that mailing list thread, however, Ben Roon points to a workaround that might be useful for you, too.

Workaround:

You need to write the new coordinates of the data points into the internal _offsets3d attribute of the Path3DCollection object returned by the 3D scatter function. Note that _offsets3d is private, so this may break in future Matplotlib versions.

Adapted to your example, it would look like this:

import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib

X = np.random.rand(100, 3)*10

plt.ion()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
sc = ax.scatter(X[:, 0], X[:, 1], X[:, 2])  # keep a handle to the returned collection
fig.show()

for i in range(0, 20):
    plt.pause(1)  # runs the GUI event loop so the window actually refreshes

    Y = np.random.rand(100, 3)*5

    # Overwrite the (private) 3D offsets instead of calling ax.cla()
    sc._offsets3d = (Y[:, 0], Y[:, 1], Y[:, 2])
    plt.draw()
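The same workaround can be wrapped in a small helper so the update step is reusable. This is a sketch, assuming the private _offsets3d attribute keeps the (xs, ys, zs) layout used above; update_scatter3d is a hypothetical helper name, and the headless Agg backend is selected only so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line for an interactive window
import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers '3d' on older Matplotlib


def update_scatter3d(sc, points):
    """Move the points of an existing 3D scatter in place (hypothetical helper).

    Relies on the private ``_offsets3d`` attribute of the collection returned
    by the 3D ``scatter`` call; this is not public API and may change.
    """
    points = np.asarray(points, dtype=float)
    sc._offsets3d = (points[:, 0], points[:, 1], points[:, 2])
    # Request a redraw instead of clearing and rebuilding the axes
    sc.axes.figure.canvas.draw_idle()


fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
X = np.random.rand(100, 3) * 10
sc = ax.scatter(X[:, 0], X[:, 1], X[:, 2])

rng = np.random.default_rng(0)
for _ in range(5):
    update_scatter3d(sc, rng.random((100, 3)) * 5)
    fig.canvas.draw()  # force a render so the new positions get projected
```

In the interactive loop from the answer, you would call update_scatter3d(sc, Y) followed by plt.pause(1) instead of forcing fig.canvas.draw(). Because the axes are never cleared, the gridline settings from the first draw are kept.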

Other tips

I could narrow it down to the use of cla():

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

x, y = np.meshgrid(np.linspace(-2, 2), np.linspace(-2, 2))

# First rendering of the surface
ax.plot_surface(x, y, x**2 + y**2)
fig.savefig("fig_a.png")

# Clear the axes and plot exactly the same surface again
ax.cla()
ax.plot_surface(x, y, x**2 + y**2)

fig.savefig("fig_b.png")

These are the resulting plots: fig_a.png (before ax.cla()) and fig_b.png (after ax.cla()).

This is only a workaround, as it does not resolve the issue with ax.cla() pointed out by MaxNoe. It is also not particularly pretty, since it clears the entire figure, but it does the job:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig1 = plt.figure()
ax1 = fig1.add_subplot(111, projection='3d')

x, y = np.meshgrid(np.linspace(-2, 2), np.linspace(-2, 2))

ax1.plot_surface(x, y, x**2 + y**2)
fig1.savefig("fig_a.png")

fig1.clf()  # clear the whole figure, not just the axes
ax1 = fig1.add_subplot(111, projection='3d')  # create a fresh 3D axes
ax1.plot_surface(x, y, x**2 + y**2)

fig1.savefig("fig_b.png")
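Applied to the scatter loop from the question, the clf() workaround means rebuilding the 3D axes on every iteration. A minimal sketch, assuming it is acceptable to lose axis labels and limits on each redraw; redraw_scatter3d is a hypothetical helper name, it carries over only the view angle (elev/azim) of the previous axes, and Agg is selected only so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line for an interactive window
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers '3d' on older Matplotlib


def redraw_scatter3d(fig, points):
    """Clear the whole figure and draw a fresh 3D scatter (hypothetical helper)."""
    old = fig.axes[0] if fig.axes else None
    # Remember the current view angle so rotating the plot is not undone each frame
    view = (old.elev, old.azim) if old is not None and old.name == "3d" else None
    fig.clf()  # removes every axes, unlike ax.cla()
    ax = fig.add_subplot(111, projection="3d")
    if view is not None:
        ax.view_init(*view)
    ax.scatter(points[:, 0], points[:, 1], points[:, 2])
    return ax


fig = plt.figure()
rng = np.random.default_rng(0)
for i in range(3):
    ax = redraw_scatter3d(fig, rng.random((100, 3)) * 5)
    fig.canvas.draw()  # each frame is rendered from a freshly created axes
```

Rebuilding the axes is more expensive than updating _offsets3d, but it only uses public calls (clf, add_subplot, view_init).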

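I'd suggest using ax = fig.gca(projection='3d') instead of ax = fig.add_subplot(111, projection='3d'). (Passing keyword arguments such as projection to fig.gca() was deprecated in Matplotlib 3.4 and is not supported in recent releases, where fig.add_subplot(projection='3d') is the way to create the 3D axes.)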
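On Matplotlib versions where gca() takes no keyword arguments, its "return the current 3D axes or create one" behaviour can be spelled out by hand. A sketch, assuming the first 3D axes found on the figure is the one you want to reuse; get_3d_axes is a hypothetical helper name, and Agg is selected only so the snippet runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line for an interactive window
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers '3d' on older Matplotlib


def get_3d_axes(fig):
    """Return an existing 3D axes on ``fig``, or add one (hypothetical helper)."""
    for ax in fig.axes:
        if ax.name == "3d":  # 3D axes report their projection name as "3d"
            return ax
    return fig.add_subplot(111, projection="3d")


fig = plt.figure()
ax_first = get_3d_axes(fig)  # creates the 3D axes
ax_again = get_3d_axes(fig)  # reuses it instead of stacking a second one
```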

Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow