1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
|
import numpy as np
import matplotlib.pyplot as plt
# Some messy visualization functions I hacked together
def cross_r(s, fs):
    """Plot 1D cuts of each field in ``fs`` through the central grid index.

    For every field six curves are added to the current matplotlib figure:
    three along the x, y and z index lines passing through index
    ``s.Nres // 2``, and three along the in-plane diagonals
    ``f[i, k, k]``, ``f[k, k, i]`` and ``f[k, i, k]``.

    Args:
        s: Object exposing ``Nres`` and the coordinate arrays ``rx``, ``ry``,
            ``rz``, each indexed as ``[ix, iy, iz]``.
            NOTE(review): assumed to have shape (Nres, Nres, Nres) -- confirm.
        fs: Iterable of 3D arrays indexed as ``[ix, iy, iz]``.
    """
    i = s.Nres // 2

    # Coordinate values along the three index lines through the centre.
    xs = s.rx[:, i, i]
    ys = s.ry[i, :, i]
    zs = s.rz[i, i, :]

    # To visualize the diagonals: the distance sqrt(a**2 + b**2) is unsigned,
    # so indices below the centre get a minus sign to keep the curve from
    # folding back onto itself.  The positive half is sized ``Nres - i`` so
    # the array has ``Nres`` entries for odd ``Nres`` as well as even.
    signs = np.array([-1] * i + [1] * (s.Nres - i))

    # The diagonal abscissas do not depend on the field; compute them once.
    diag_yz = signs * np.sqrt(ys**2 + zs**2)
    diag_xy = signs * np.sqrt(xs**2 + ys**2)
    diag_xz = signs * np.sqrt(xs**2 + zs**2)

    for f in fs:
        plt.plot(xs, f[:, i, i])
        plt.plot(ys, f[i, :, i])
        plt.plot(zs, f[i, i, :])
        # np.diagonal appends the diagonal as the last axis; [i] then picks
        # the central index of the remaining (non-diagonal) axis.
        plt.plot(diag_yz, np.diagonal(f, axis1=1, axis2=2)[i])
        plt.plot(diag_xy, np.diagonal(f, axis1=0, axis2=1)[i])
        plt.plot(diag_xz, np.diagonal(f, axis1=0, axis2=2)[i])

    # NOTE(review): original indentation was lost; the figure is assumed to be
    # shown once, after all fields have been drawn on the same axes.
    plt.show()
def height_r(s, f, plane="xz"):
    """Plot a central planar slice of a 3D field as a height surface.

    The slice is taken at index ``s.Nres // 2`` of the axis perpendicular to
    ``plane`` and drawn with ``plot_surface`` on a new 3D figure.

    Args:
        s: Object exposing ``Nres`` and the coordinate arrays ``rx``, ``ry``,
            ``rz``, each indexed as ``[ix, iy, iz]``.
        f: 3D array indexed as ``[ix, iy, iz]``.
        plane: Which slice to draw: ``"xy"``, ``"yz"`` or ``"xz"``. Defaults
            to ``"xz"``, the slice the function has always drawn.

    Raises:
        ValueError: If ``plane`` is not one of the supported names.
    """
    i = s.Nres // 2

    # Coordinate values along the three index lines through the centre.
    x = s.rx[:, i, i]
    y = s.ry[i, :, i]
    z = s.rz[i, i, :]

    # (first in-plane axis, second in-plane axis, field values on the slice)
    slices = {
        "xy": (x, y, f[:, :, i]),
        "yz": (y, z, f[i, :, :]),
        "xz": (x, z, f[:, i, :]),
    }
    if plane not in slices:
        raise ValueError(
            f"plane must be one of {sorted(slices)}, got {plane!r}"
        )
    u, v, values = slices[plane]

    # indexing="ij" gives grids with U[a, b] == u[a] and V[a, b] == v[b],
    # matching values[a, b].  The default "xy" indexing returns transposed
    # grids, which would draw the surface with its two axes swapped.
    grid_u, grid_v = np.meshgrid(u, v, indexing="ij")

    plt.figure()
    ax = plt.axes(projection="3d")
    # Label the horizontal axes after the plane actually being plotted.
    ax.set_xlabel(plane[0])
    ax.set_ylabel(plane[1])
    ax.plot_surface(grid_u, grid_v, values, cmap="viridis")
    plt.show()
def shape_r(s, f, c):
    """Scatter-plot the grid points where ``f`` exceeds a threshold.

    Points with ``f >= c`` are drawn with the ``"Reds_r"`` colormap and points
    with ``f <= -c`` with the ``"Blues"`` colormap, each coloured by the field
    value at that point.  The axes limits span the full coordinate range of
    the grid, so the view does not depend on which points are selected.

    Args:
        s: Object exposing the coordinate arrays ``rx``, ``ry``, ``rz``, each
            indexed as ``[ix, iy, iz]``.
            NOTE(review): assumed to be NumPy arrays with the same shape as
            ``f`` (Nres points per axis) -- confirm.
        f: 3D array indexed as ``[ix, iy, iz]``.
        c: Threshold applied as ``f >= c`` and ``f <= -c``.
    """
    # Boolean masks replace the Python-level triple loop over the grid; mask
    # indexing returns elements in C order, i.e. the same (ix, iy, iz) order
    # the nested loops visited them in.
    positive = f >= c
    negative = f <= -c

    plt.figure()
    ax = plt.axes(projection="3d")
    ax.set_xlim(np.amin(s.rx), np.amax(s.rx))
    ax.set_ylim(np.amin(s.ry), np.amax(s.ry))
    ax.set_zlim(np.amin(s.rz), np.amax(s.rz))
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    ax.scatter(
        s.rx[positive], s.ry[positive], s.rz[positive],
        c=f[positive], s=10, cmap="Reds_r",
    )
    ax.scatter(
        s.rx[negative], s.ry[negative], s.rz[negative],
        c=f[negative], s=10, cmap="Blues",
    )
    plt.show()
|