summaryrefslogtreecommitdiff
path: root/visualize.py
diff options
context:
space:
mode:
Diffstat (limited to 'visualize.py')
-rw-r--r--visualize.py76
1 files changed, 76 insertions, 0 deletions
diff --git a/visualize.py b/visualize.py
new file mode 100644
index 0000000..6746458
--- /dev/null
+++ b/visualize.py
@@ -0,0 +1,76 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+# Some messy visualization functions I hacked together
+
+
+def cross_r(s, fs):
+ i = s.Nres // 2
+
+ # To visualize the diagonals
+ signs = np.array([-1] * i + [1] * i)
+
+ xs = s.rx[:, i, i]
+ ys = s.ry[i, :, i]
+ zs = s.rz[i, i, :]
+
+ for f in fs:
+ plt.plot(xs, f[:, i, i])
+ plt.plot(ys, f[i, :, i])
+ plt.plot(zs, f[i, i, :])
+ plt.plot(signs * np.sqrt(ys**2 + zs**2), np.diagonal(f, axis1=1, axis2=2)[i])
+ plt.plot(signs * np.sqrt(xs**2 + ys**2), np.diagonal(f, axis1=0, axis2=1)[i])
+ plt.plot(signs * np.sqrt(xs**2 + zs**2), np.diagonal(f, axis1=0, axis2=2)[i])
+ plt.show()
+
+
+def height_r(s, f):
+ i = s.Nres // 2
+
+ fig = plt.figure()
+ ax = plt.axes(projection="3d")
+ Xxy, Yxy = np.meshgrid(s.rx[:, i, i], s.ry[i, :, i])
+ Yyz, Zyz = np.meshgrid(s.ry[i, :, i], s.rz[i, i, :])
+ Xxz, Zxz = np.meshgrid(s.rx[:, i, i], s.rz[i, i, :])
+ ax.set_xlabel("x")
+ ax.set_ylabel("y")
+ #ax.plot_surface(Xxy, Yxy, f[:, :, i], cmap="viridis")
+ #ax.plot_surface(Yyz, Zyz, f[i, :, :], cmap="viridis")
+ ax.plot_surface(Xxz, Zxz, f[:, i, :], cmap="viridis")
+ plt.show()
+
+
+def shape_r(s, f, c):
+ ppts_x, ppts_y, ppts_z = [], [], []
+ npts_x, npts_y, npts_z = [], [], []
+ pvals = []
+ nvals = []
+
+ for ix in range(s.Nres):
+ for iy in range(s.Nres):
+ for iz in range(s.Nres):
+ if f[ix, iy, iz] >= c:
+ ppts_x.append(s.rx[ix, iy, iz])
+ ppts_y.append(s.ry[ix, iy, iz])
+ ppts_z.append(s.rz[ix, iy, iz])
+ pvals.append(f[ix, iy, iz])
+ if f[ix, iy, iz] <= -c:
+ npts_x.append(s.rx[ix, iy, iz])
+ npts_y.append(s.ry[ix, iy, iz])
+ npts_z.append(s.rz[ix, iy, iz])
+ nvals.append(f[ix, iy, iz])
+
+ fig = plt.figure()
+ ax = plt.axes(projection="3d")
+ ax.set_xlim(np.amin(s.rx), np.amax(s.rx))
+ ax.set_ylim(np.amin(s.ry), np.amax(s.ry))
+ ax.set_zlim(np.amin(s.rz), np.amax(s.rz))
+ ax.set_xlabel("x")
+ ax.set_ylabel("y")
+ ax.set_zlabel("z")
+
+ ax.scatter(ppts_x, ppts_y, ppts_z, c=pvals, s=10, cmap="Reds_r")
+ ax.scatter(npts_x, npts_y, npts_z, c=nvals, s=10, cmap="Blues")
+ plt.show()
+