diff options
Diffstat (limited to 'visualize.py')
-rw-r--r-- | visualize.py | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/visualize.py b/visualize.py new file mode 100644 index 0000000..6746458 --- /dev/null +++ b/visualize.py @@ -0,0 +1,76 @@ +import numpy as np +import matplotlib.pyplot as plt + + +# Some messy visualization functions I hacked together + + +def cross_r(s, fs): + i = s.Nres // 2 + + # To visualize the diagonals + signs = np.array([-1] * i + [1] * i) + + xs = s.rx[:, i, i] + ys = s.ry[i, :, i] + zs = s.rz[i, i, :] + + for f in fs: + plt.plot(xs, f[:, i, i]) + plt.plot(ys, f[i, :, i]) + plt.plot(zs, f[i, i, :]) + plt.plot(signs * np.sqrt(ys**2 + zs**2), np.diagonal(f, axis1=1, axis2=2)[i]) + plt.plot(signs * np.sqrt(xs**2 + ys**2), np.diagonal(f, axis1=0, axis2=1)[i]) + plt.plot(signs * np.sqrt(xs**2 + zs**2), np.diagonal(f, axis1=0, axis2=2)[i]) + plt.show() + + +def height_r(s, f): + i = s.Nres // 2 + + fig = plt.figure() + ax = plt.axes(projection="3d") + Xxy, Yxy = np.meshgrid(s.rx[:, i, i], s.ry[i, :, i]) + Yyz, Zyz = np.meshgrid(s.ry[i, :, i], s.rz[i, i, :]) + Xxz, Zxz = np.meshgrid(s.rx[:, i, i], s.rz[i, i, :]) + ax.set_xlabel("x") + ax.set_ylabel("y") + #ax.plot_surface(Xxy, Yxy, f[:, :, i], cmap="viridis") + #ax.plot_surface(Yyz, Zyz, f[i, :, :], cmap="viridis") + ax.plot_surface(Xxz, Zxz, f[:, i, :], cmap="viridis") + plt.show() + + +def shape_r(s, f, c): + ppts_x, ppts_y, ppts_z = [], [], [] + npts_x, npts_y, npts_z = [], [], [] + pvals = [] + nvals = [] + + for ix in range(s.Nres): + for iy in range(s.Nres): + for iz in range(s.Nres): + if f[ix, iy, iz] >= c: + ppts_x.append(s.rx[ix, iy, iz]) + ppts_y.append(s.ry[ix, iy, iz]) + ppts_z.append(s.rz[ix, iy, iz]) + pvals.append(f[ix, iy, iz]) + if f[ix, iy, iz] <= -c: + npts_x.append(s.rx[ix, iy, iz]) + npts_y.append(s.ry[ix, iy, iz]) + npts_z.append(s.rz[ix, iy, iz]) + nvals.append(f[ix, iy, iz]) + + fig = plt.figure() + ax = plt.axes(projection="3d") + ax.set_xlim(np.amin(s.rx), np.amax(s.rx)) + ax.set_ylim(np.amin(s.ry), np.amax(s.ry)) + ax.set_zlim(np.amin(s.rz), np.amax(s.rz)) + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_zlabel("z") + + ax.scatter(ppts_x, ppts_y, ppts_z, c=pvals, s=10, cmap="Reds_r") + ax.scatter(npts_x, npts_y, npts_z, c=nvals, s=10, cmap="Blues") + plt.show() + |