Problem

I'm having trouble plotting multiple sets of data onto a single 3D scatter plot. I have a system of three equations, and I'm calculating the zeros of the equations using linalg. I then plot each set of zeros on a 3D plot. I'm varying the value of one of my parameters, s, and observing how the zeros change. I'd like to plot all of the data sets on one 3D scatter plot so that it's easy to compare how they differ, but I keep getting a separate graph for each data set. Can anyone figure out what I need to fix?

import numpy as np
from numpy import linalg
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

plt.close('all')
#Will be solving the following system of equations:
#sx-(b/r)z=0
#-x+ry+(s-b)z=0
#(1/r)x+y-z=0

r=50.0
b=17.0/4.0
s=[10.0,20.0,7.0,r/b]

color=['r','b','g','y']
markers=['s','o','^','d']

def system(s,b,r,color,m):
    # first creates the matrix as an array so the parameters can be changed from outside,
    # and then converts the array into a matrix
    u_arr=np.array([[s,0,-b/r],[-1,r,s-b],[1/r,1,-1]])
    u_mat=np.matrix(u_arr)

    U_mat=linalg.inv(u_mat)

    #converts matrix into an array and then into a list to manipulate
    x_zeros=np.array(U_mat[0]).reshape(-1).tolist()
    y_zeros=np.array(U_mat[1]).reshape(-1).tolist()
    z_zeros=np.array(U_mat[2]).reshape(-1).tolist()

    zeros=[x_zeros,y_zeros,z_zeros]
    coordinates=['x','y','z']

    print('+'*70)
    print('For s=%1.1f:' % s)
    print('\n')

    for i in range(3):
        print('For the %s direction, the roots are: ' % coordinates[i])
        for j in range(3):
            print(zeros[i][j])
        print('-'*50)

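    fig3d=plt.figure()
    # add_subplot attaches the 3D axes to the figure; recent matplotlib versions
    # no longer do this automatically for Axes3D(fig3d)
    ax=fig3d.add_subplot(projection='3d')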
    ax.scatter(x_zeros,y_zeros,z_zeros,c=color,marker=m)
    plt.title('Zeros for a Given System of Equations for s=%1.1f' % (s))
    ax.set_xlabel('Zeros in x Direction')
    ax.set_ylabel('Zeros in y Direction')
    ax.set_zlabel('Zeros in z Direction')
    plt.show()

for k in range(len(s)):
    system(s[k],b,r,color[k],markers[k])

Thanks in advance for any help.


Solution

You are creating a new figure and a new axes instance each time system() is called, because plt.figure() and the axes creation live inside the function. Instead, pass ax as an argument to system:

def system(s,b,r,color,m,ax):
    # ... compute x_zeros, y_zeros, z_zeros as before, but do not create a figure or call plt.show() here
    ax.scatter(x_zeros,y_zeros,z_zeros,c=color,marker=m)

Then create the figure and axes once, before the loop, and call plt.show() only after all data sets have been plotted:

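fig3d=plt.figure()
# add_subplot attaches the 3D axes to the figure; recent matplotlib versions
# no longer do this automatically for Axes3D(fig3d)
ax=fig3d.add_subplot(projection='3d')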

for k in range(len(s)):
    system(s[k],b,r,color[k],markers[k], ax)

plt.show()
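A minimal way to check this diagnosis in isolation, independent of the equations: the sketch below uses two hypothetical helpers, plot_with_new_figure() (which, like the original system(), creates a figure on every call) and plot_on_axes() (which draws onto an axes passed in), and counts the figures and scatter collections each approach produces. It assumes matplotlib 3.2 or later, where projection='3d' works without importing Axes3D, and selects the non-interactive Agg backend so nothing is displayed.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: nothing is displayed
import matplotlib.pyplot as plt


def plot_with_new_figure(x):
    # Hypothetical helper, like the original system(): a new figure and 3D axes on every call
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.scatter([x], [x], [x])


def plot_on_axes(x, ax):
    # Hypothetical helper, like the fixed system(): draw onto the axes supplied by the caller
    ax.scatter([x], [x], [x])


# Approach 1: one figure per call
plt.close("all")
for x in range(4):
    plot_with_new_figure(x)
n_figs_per_call = len(plt.get_fignums())

# Approach 2: create the axes once, then pass it in
plt.close("all")
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
for x in range(4):
    plot_on_axes(x, ax)
n_figs_shared = len(plt.get_fignums())
n_scatters = len(ax.collections)  # one collection per scatter call

print(n_figs_per_call, n_figs_shared, n_scatters)  # prints: 4 1 4
```

The first approach leaves four separate figures open, while the second leaves a single figure whose one axes holds all four scatter plots.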

With ax created once outside the loop, all of the scatter calls draw onto the same axes. You may also want to set the title and axis labels outside of system(), for example by splitting it into two functions: one that sets up the plot and one that computes the required data and plots it.
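The two-function split suggested above could look roughly like the following sketch. The function names setup_axes() and plot_system() are hypothetical; the coefficient matrix and parameter values are copied from the question, but it uses a plain ndarray with np.linalg.inv instead of np.matrix. It also adds a label per data set and a legend, which the original code does not have, so the four values of s can be told apart on the shared plot. It assumes matplotlib 3.2 or later for projection='3d'.

```python
import numpy as np
import matplotlib.pyplot as plt


def setup_axes():
    """Hypothetical helper: create one figure with one 3D axes and set the shared labels."""
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.set_title('Zeros for a Given System of Equations')
    ax.set_xlabel('Zeros in x Direction')
    ax.set_ylabel('Zeros in y Direction')
    ax.set_zlabel('Zeros in z Direction')
    return fig, ax


def plot_system(s, b, r, color, marker, ax):
    """Hypothetical helper: invert the coefficient matrix for one s and scatter it on ax."""
    u = np.array([[s, 0, -b / r], [-1, r, s - b], [1 / r, 1, -1]])
    U = np.linalg.inv(u)
    # Rows 0, 1, 2 correspond to the question's x_zeros, y_zeros, z_zeros
    ax.scatter(U[0], U[1], U[2], c=color, marker=marker, label='s=%1.1f' % s)
    return U


r = 50.0
b = 17.0 / 4.0
s_values = [10.0, 20.0, 7.0, r / b]
colors = ['r', 'b', 'g', 'y']
markers = ['s', 'o', '^', 'd']

plt.close('all')
fig3d, ax = setup_axes()  # the figure and axes are created exactly once
for s_k, c_k, m_k in zip(s_values, colors, markers):
    plot_system(s_k, b, r, c_k, m_k, ax)
ax.legend()  # one legend entry per value of s
plt.show()
```

Keeping the plot setup in its own function means the title and labels are set once, and plot_system() only has to know which axes to draw on.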

License: CC-BY-SA with attribution
Not affiliated with StackOverflow