summaryrefslogtreecommitdiff
path: root/visualize.py
blob: 674645867c7181bc5c53e759ad25899b492886f8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np
import matplotlib.pyplot as plt


# Some messy visualization functions I hacked together


def cross_r(s, fs):
    """Plot 1D cross-sections of every field in ``fs`` through the grid centre.

    For each 3D array ``f`` in ``fs`` six curves are drawn into the current
    figure, all passing through the centre index ``i = s.Nres // 2``:

    * ``f[:, i, i]``, ``f[i, :, i]`` and ``f[i, i, :]`` against the matching
      lines taken from ``s.rx``, ``s.ry`` and ``s.rz``;
    * the three in-plane diagonals ``f[i, k, k]``, ``f[k, k, i]`` and
      ``f[k, i, k]`` against a signed in-plane radius.

    Parameters:
        s: object providing ``Nres`` and the 3D arrays ``rx``, ``ry``, ``rz``
            (presumably coordinate grids with x along axis 0, y along axis 1
            and z along axis 2 -- TODO confirm against the caller).
        fs: iterable of 3D arrays indexed like ``s.rx``.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    i = s.Nres // 2

    # Sign used to unfold the diagonals onto a signed radius: negative for
    # indices below the centre, positive from the centre on.  It is built from
    # the full index range so it always has Nres entries; the previous
    # ``[-1] * i + [1] * i`` was one entry short whenever Nres was odd.
    signs = np.where(np.arange(s.Nres) < i, -1, 1)

    xs = s.rx[:, i, i]
    ys = s.ry[i, :, i]
    zs = s.rz[i, i, :]

    # The diagonal abscissae do not depend on the field, so compute them once.
    r_yz = signs * np.sqrt(ys**2 + zs**2)
    r_xy = signs * np.sqrt(xs**2 + ys**2)
    r_xz = signs * np.sqrt(xs**2 + zs**2)

    for f in fs:
        # Cuts along the three index axes.
        plt.plot(xs, f[:, i, i])
        plt.plot(ys, f[i, :, i])
        plt.plot(zs, f[i, i, :])
        # np.diagonal moves the diagonal to the last axis, so [i] picks the
        # centre index of the one remaining axis.
        plt.plot(r_yz, np.diagonal(f, axis1=1, axis2=2)[i])
        plt.plot(r_xy, np.diagonal(f, axis1=0, axis2=1)[i])
        plt.plot(r_xz, np.diagonal(f, axis1=0, axis2=2)[i])
    plt.show()


def height_r(s, f, plane="xz"):
    """Show a central 2D slice of ``f`` as a 3D height map.

    Parameters:
        s: object providing ``Nres`` and the 3D arrays ``rx``, ``ry``, ``rz``
            (presumably coordinate grids with x along axis 0, y along axis 1
            and z along axis 2 -- TODO confirm against the caller).
        f: 3D array indexed like ``s.rx``.
        plane: which slice through the centre index ``s.Nres // 2`` to draw:
            ``"xy"`` (``f[:, :, i]``), ``"yz"`` (``f[i, :, :]``) or ``"xz"``
            (``f[:, i, :]``, the default).

    Raises:
        ValueError: if ``plane`` is not one of the three supported names.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    if plane not in ("xy", "yz", "xz"):
        raise ValueError(f"plane must be 'xy', 'yz' or 'xz', got {plane!r}")

    i = s.Nres // 2

    # 1D coordinate lines through the grid centre, one per index axis.
    lines = {"x": s.rx[:, i, i], "y": s.ry[i, :, i], "z": s.rz[i, i, :]}
    # Index that fixes the axis perpendicular to the requested plane.
    cuts = {
        "xy": (slice(None), slice(None), i),
        "yz": (i, slice(None), slice(None)),
        "xz": (slice(None), i, slice(None)),
    }

    first, second = plane
    # indexing="ij" makes the grids follow the same [first, second] index
    # order as the slice of f; the default "xy" indexing would pair every
    # height with the transposed grid point.
    grid_a, grid_b = np.meshgrid(lines[first], lines[second], indexing="ij")

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    # Label the horizontal axes with the coordinates that are actually drawn.
    ax.set_xlabel(first)
    ax.set_ylabel(second)
    ax.plot_surface(grid_a, grid_b, f[cuts[plane]], cmap="viridis")
    plt.show()


def shape_r(s, f, c):
    """Scatter-plot the grid points where ``f`` reaches the threshold ``c``.

    Points with ``f >= c`` are drawn with the "Reds_r" colormap and points
    with ``f <= -c`` with the "Blues" colormap, each coloured by its value of
    ``f``.  Only indices below ``s.Nres`` along each axis are considered.

    Parameters:
        s: object providing ``Nres`` and the 3D arrays ``rx``, ``ry``, ``rz``
            giving the coordinates of every grid point.
        f: 3D array indexed like ``s.rx``.
        c: threshold applied as ``f >= c`` and ``f <= -c``.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    n = s.Nres
    cube = (slice(n), slice(n), slice(n))
    values = f[cube]
    rx, ry, rz = s.rx[cube], s.ry[cube], s.rz[cube]

    # Boolean masks select the points in C order, i.e. in the same order as a
    # nested loop over (ix, iy, iz), without iterating in Python.
    positive = values >= c
    negative = values <= -c

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    # Fix the limits to the full grid so the view does not depend on which
    # points happen to pass the threshold.
    ax.set_xlim(np.amin(s.rx), np.amax(s.rx))
    ax.set_ylim(np.amin(s.ry), np.amax(s.ry))
    ax.set_zlim(np.amin(s.rz), np.amax(s.rz))
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")

    ax.scatter(rx[positive], ry[positive], rz[positive],
               c=values[positive], s=10, cmap="Reds_r")
    ax.scatter(rx[negative], ry[negative], rz[negative],
               c=values[negative], s=10, cmap="Blues")
    plt.show()