Question

I am updating a 3d scatter plot with every iteration of a loop. When the plot is redrawn, the gridlines "go through" or "cover" the points, which makes my data more difficult to visualize. If I build a single 3d plot (no loop updating) this does not happen. The code below demonstrates the simplest case:

import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import time

X = np.random.rand(100, 3)*10
Y = np.random.rand(100, 3)*5

plt.ion()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2])
plt.draw()

for i in range(0, 20):
    time.sleep(3)   #make changes more apparent/easy to see

    Y = np.random.rand(100, 3)*5
    ax.cla()    
    ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2])
    plt.draw()

Has anyone else encountered this problem?


Solution

It looks like MaxNoe is right that the problem lies in the ax.cla() or plt.cla() call. In fact, it seems to be a known issue.

That leaves a problem: clearing the axes does not work properly in 3D plots, and for 3D scatters there is no clean way to change the coordinates of the data points (a la sc.set_data(new_values)), as discussed in this mail list (I didn't find anything more recent).

In the mail list, however, Ben Root points to a workaround that might be useful for you, too.

Workaround:

You need to set the new coordinates of the data points in the internal _offsets3d attribute of the Path3DCollection object returned by the scatter function.

Your example adapted would look like:

import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import time

X = np.random.rand(100, 3)*10
Y = np.random.rand(100, 3)*5

plt.ion()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
sc = ax.scatter(X[:, 0], X[:, 1], X[:, 2])
fig.show()

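for i in range(0, 20):
    plt.pause(1)   # let the GUI event loop run so each update is visible

    Y = np.random.rand(100, 3)*5

    # overwrite the private _offsets3d attribute instead of calling ax.cla()
    sc._offsets3d = (Y[:,0], Y[:,1], Y[:,2])
    plt.draw()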
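
The same idea can be wrapped in a small helper so the private attribute is touched in only one place. This is a minimal sketch, not an official Matplotlib API: update_scatter3d is a hypothetical name, the Agg backend is chosen only so the snippet runs without a display, and _offsets3d is a private attribute whose behaviour may change between Matplotlib releases.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on older Matplotlib)


def update_scatter3d(sc, points):
    """Move an existing 3D scatter to new (N, 3) coordinates (hypothetical helper)."""
    points = np.asarray(points, dtype=float)
    if points.ndim != 2 or points.shape[1] != 3:
        raise ValueError("points must have shape (N, 3)")
    # Assumption: _offsets3d is private and undocumented; it holds (xs, ys, zs).
    sc._offsets3d = (points[:, 0], points[:, 1], points[:, 2])
    sc.stale = True  # mark the artist as needing a redraw
    return sc


fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
start = np.random.rand(100, 3) * 10
sc = ax.scatter(start[:, 0], start[:, 1], start[:, 2])

new_points = np.random.rand(100, 3) * 5
update_scatter3d(sc, new_points)
fig.canvas.draw()  # redraw without ever calling ax.cla()
```

Because the axes are never cleared, the axis limits stay as they were set by the first scatter; if the new data has a different range, the limits would have to be adjusted separately.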

Other tips

I could narrow it down to the use of cla():

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

x, y = np.meshgrid(np.linspace(-2,2), np.linspace(-2,2))

ax.plot_surface(x,y, x**2+y**2)
fig.savefig("fig_a.png")

ax.cla()
ax.plot_surface(x,y, x**2+y**2)

fig.savefig("fig_b.png")

The resulting plots are saved as fig_a.png and fig_b.png.

This is only a workaround, as it does not resolve the issue with ax.cla() pointed out by MaxNoe. It is also not particularly pretty, since it clears the entire figure, but it does the desired task:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig1 = plt.figure()
ax1 = fig1.add_subplot(111, projection='3d')

x, y = np.meshgrid(np.linspace(-2,2), np.linspace(-2,2))

ax1.plot_surface(x,y, x**2+y**2)
fig1.savefig("fig_a.png")

fig1.clf()
ax1 = fig1.add_subplot(111, projection='3d')
ax1.plot_surface(x,y, x**2+y**2)

fig1.savefig("fig_b.png")
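
Applied to the scatter loop from the question, the clear-the-figure approach might look like the sketch below. The helper name redraw_scatter is hypothetical, and the Agg backend is used only so the snippet runs without a display; in an interactive session you would keep plt.ion() and plt.pause() as in the question.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on older Matplotlib)


def redraw_scatter(fig, points):
    """Clear the whole figure and draw a fresh 3D scatter (hypothetical helper)."""
    fig.clf()  # clear the figure instead of calling ax.cla()
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(points[:, 0], points[:, 1], points[:, 2])
    return ax


fig = plt.figure()
for _ in range(3):
    points = np.random.rand(100, 3) * 5
    ax = redraw_scatter(fig, points)
    fig.canvas.draw()
```

Each iteration builds a new axes object, so any view angle or axis limits set on the previous axes are lost and would need to be reapplied.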

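I'd suggest using ax = fig.gca(projection='3d') instead of ax = fig.add_subplot(111, projection='3d'). Note that passing keyword arguments to gca() was deprecated in Matplotlib 3.4, so this only applies to older releases.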
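
For comparison, a 3D axes can also be requested when the figure is created, through the subplot_kw argument of plt.subplots. This is a small sketch of that alternative, not something proposed in the answers above; the Agg backend is again chosen only so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on older Matplotlib)

# Create the figure and a single 3D axes in one call.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

x, y = np.meshgrid(np.linspace(-2, 2), np.linspace(-2, 2))
ax.plot_surface(x, y, x**2 + y**2)
fig.canvas.draw()
```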

Licensed under: CC-BY-SA with attribution