"""Quick visualization helpers for a scalar field sampled on a 3-D grid.

Every function takes ``s``, an object exposing ``Nres`` and coordinate arrays
``rx``, ``ry``, ``rz``. The code indexes them as ``rx[:, i, i]``,
``ry[i, :, i]`` and ``rz[i, i, :]``. This presumably means the arrays were
built with ``np.meshgrid(..., indexing="ij")`` and have the same shape as the
field arrays being plotted — TODO confirm against the producer of ``s``.
"""
import numpy as np
import matplotlib.pyplot as plt


def cross_r(s, fs):
    """Plot 1-D cuts of each field in ``fs`` through the grid centre.

    For every field this draws six lines on one figure:

    * the three axis-aligned cuts (along x, y and z);
    * the three face diagonals through the centre, in the yz, xy and xz planes.

    Each diagonal is drawn against a signed radius (negative before the
    centre index, positive from it onward), so it reads as a line through
    the origin.

    Args:
        s: Grid description with ``Nres``, ``rx``, ``ry``, ``rz``.
        fs: Iterable of 3-D arrays indexed ``[ix, iy, iz]``.
    """
    n = s.Nres
    i = n // 2
    # Sign per grid index so the diagonal radius runs from negative to
    # positive. Built from an index comparison so its length is always
    # Nres (the original [-1]*i + [1]*i was one short for odd Nres).
    signs = np.where(np.arange(n) < i, -1, 1)
    xs = s.rx[:, i, i]
    ys = s.ry[i, :, i]
    zs = s.rz[i, i, :]

    # Signed radii along each face diagonal; they do not depend on the field.
    r_yz = signs * np.sqrt(ys**2 + zs**2)
    r_xy = signs * np.sqrt(xs**2 + ys**2)
    r_xz = signs * np.sqrt(xs**2 + zs**2)

    for f in fs:
        plt.plot(xs, f[:, i, i])
        plt.plot(ys, f[i, :, i])
        plt.plot(zs, f[i, i, :])
        # np.diagonal moves the diagonal to the last axis; the remaining axis
        # is then fixed at the centre index i.
        plt.plot(r_yz, np.diagonal(f, axis1=1, axis2=2)[i])  # f[i, k, k]
        plt.plot(r_xy, np.diagonal(f, axis1=0, axis2=1)[i])  # f[k, k, i]
        plt.plot(r_xz, np.diagonal(f, axis1=0, axis2=2)[i])  # f[k, i, k]
    plt.show()


def height_r(s, f, plane="xz"):
    """Show a central 2-D slice of ``f`` as a 3-D height (surface) plot.

    Args:
        s: Grid description with ``Nres``, ``rx``, ``ry``, ``rz``.
        f: 3-D array indexed ``[ix, iy, iz]``.
        plane: Which central slice to draw: ``"xy"``, ``"yz"`` or ``"xz"``.
            Defaults to ``"xz"`` (the slice at the centre y index), which
            matches the original behaviour.

    Raises:
        ValueError: If ``plane`` is not one of the supported names.
    """
    i = s.Nres // 2
    coords = {
        "x": s.rx[:, i, i],
        "y": s.ry[i, :, i],
        "z": s.rz[i, i, :],
    }
    # Each slice keeps its axes in the same order as the plane name, so
    # slices[plane][a, b] lines up with (coords[plane[0]][a], coords[plane[1]][b]).
    slices = {
        "xy": f[:, :, i],
        "yz": f[i, :, :],
        "xz": f[:, i, :],
    }
    if plane not in slices:
        raise ValueError(f"plane must be one of {sorted(slices)}, got {plane!r}")
    first, second = plane

    # indexing="ij" makes the grids shape (len(first), len(second)), matching
    # the slice layout. The default "xy" indexing would transpose them.
    grid_a, grid_b = np.meshgrid(coords[first], coords[second], indexing="ij")

    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.set_xlabel(first)
    ax.set_ylabel(second)
    ax.plot_surface(grid_a, grid_b, slices[plane], cmap="viridis")
    plt.show()


def shape_r(s, f, c):
    """Scatter-plot grid points where ``|f|`` reaches the threshold ``c``.

    Points with ``f >= c`` are drawn with the ``Reds_r`` colormap. Points with
    ``f <= -c`` are drawn with the ``Blues`` colormap. Colour encodes the field
    value. Axis limits span the full coordinate range of the grid.

    Args:
        s: Grid description with ``rx``, ``ry``, ``rz``. These are assumed to
            have the same shape as ``f`` (the original looped over
            ``range(s.Nres)`` on every axis) — TODO confirm.
        f: 3-D array indexed ``[ix, iy, iz]``.
        c: Threshold value.
    """
    f = np.asarray(f)
    # Boolean masks select points in C (row-major) order, the same order as
    # the original nested ix/iy/iz loops.
    pos = f >= c
    neg = f <= -c

    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.set_xlim(np.amin(s.rx), np.amax(s.rx))
    ax.set_ylim(np.amin(s.ry), np.amax(s.ry))
    ax.set_zlim(np.amin(s.rz), np.amax(s.rz))
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    ax.scatter(s.rx[pos], s.ry[pos], s.rz[pos], c=f[pos], s=10, cmap="Reds_r")
    ax.scatter(s.rx[neg], s.ry[neg], s.rz[neg], c=f[neg], s=10, cmap="Blues")
    plt.show()