Matplotlib 3D Plots (1)

  • Matplotlib으로 3D Plot을 할 수 있습니다.
  • 많은 분들이 알고 있는 사실이지만 적극적으로 쓰이지 않습니다.
  • 막상 쓰려면 너무 낯설기도 하고 잘 모르기도 하기 때문입니다.

Reference

matplotlib tutorial: The mplot3d Toolkit
numpy.meshgrid

1. 예제 데이터

1.1. 공식 예제 데이터

  • 위 그림이 익숙한 분들이 많으실텐데, Matplotlib이 가지고 있는 3D 예제 데이터입니다.

  • mpl_toolkits.mplot3d.axes3d 모듈의 get_test_data()명령으로 불러올 수 있습니다.

  • 매개변수로 float을 받습니다. 정밀도를 나타내며 숫자가 작을수록 정밀합니다.

  • 이상하게도 공식 홈페이지에 누락되어 있습니다.

  • 3D Plot 공식 예제에서는 보통 0.5를 넣습니다.

    1
    2
    3
    4
    5
    6
    7
    from mpl_toolkits.mplot3d.axes3d import get_test_data

    X, Y, Z = get_test_data(0.5)

    print(f"X.shape={X.shape}")
    print(f"Y.shape={Y.shape}")
    print(f"Z.shape={Z.shape}")
  • 실행 결과 : meshgrid형식의 12 x 12 데이터가 생성됩니다.

    1
    2
    3
    X.shape=(12, 12)
    Y.shape=(12, 12)
    Z.shape=(12, 12)
  • get_test_data()에 들어가는 숫자를 바꿔봅니다.

  • 숫자에 반비례하여 해상도를 바꾸는 효과가 있습니다.

    1
    2
    3
    4
    5
    6
    7
    fig, axs = plt.subplots(ncols=3, figsize=(10, 4), subplot_kw={"projection":"3d"}, constrained_layout=True)

    for ax, d in zip(axs, [1, 0.5, 0.1]):
    X, Y, Z = get_test_data(d)
    dim = X.shape[0]
    ax.plot_wireframe(X, Y, Z)
    ax.set_title(f"get_test_data({d}): {dim}x{dim}", fontsize="x-large", color="gray", fontweight="bold")


1.2. 이번 예제 데이터

  • 그러나 이 글에서 우리는 저 예제 데이터를 사용하지 않을 것입니다.
  • 기본 예제 데이터를 사용한 그림을 보고 싶으시면 공식 홈페이지를 보시기 바랍니다.
  • 이 글에서는 제가 현업에서 얻은 예제 데이터 두 가지를 사용하겠습니다.
  • 보안을 위해 column name은 제거했습니다.
    1
    2
    data1 = pd.read_csv("data1.csv")
    data1.head()

1
2
data2 = pd.read_csv("data2.csv")
data2.head()


  • data1과 data2에는 Y1과 Y2가 있습니다.
  • data1은 X, Y1을 독립변수로 Y2과 Z를 얻은 것이고
  • data2는 X, Y2를 독립변수로 Y1과 Z를 얻은 것입니다.
  • 다시 말해 data1은 X와 Y1이, data2는 X와 Y2직교좌표계에 놓여 있습니다.
  • 그리고 무의미한 데이터는 삭제되어 있습니다.

2. 2D Visualization

  • 3D 데이터는 2D 이미지로 표현할 수 있습니다.
  • 여러가지 방식을 사용합니다.

2.1. ax.scatter()

  • 가장 기본적인 scatter plot을 사용합니다.

  • color로 Z를 사용하고 inferno colormap으로 데이터를 표기합니다.

    1
    2
    3
    4
    5
    6
    7
    8
    fig, axs = plt.subplots(ncols=2, figsize=(8, 4), constrained_layout=True)

    fontlabel = {"fontsize":"large", "color":"gray", "fontweight":"bold"}
    for ax, data in zip(axs, [data1, data2]):
    ydata = "Y1" if data is data1 else "Y2"
    ax.scatter(data["X"], data[ydata], c=data["Z"], cmap="inferno", s=10)
    ax.set_xlabel("X", fontdict=fontlabel)
    ax.set_ylabel(ydata, fontdict=fontlabel)


  • 왼쪽 X vs Y1은 사다리꼴 모양,

  • 오른쪽 X vs Y2는 휘어진 활 모양입니다.

  • scatter plot은 그려보기 가장 좋지만 데이터 사이 거리가 멀면 듬성듬성 구멍이 보일 수도 있습니다.

2.2. ax.imshow()

  • xy 평면에 놓인 데이터는 이미지처럼 다룰 수 있습니다.

  • ax.imshow()명령을 쓰려면 pandas의 pivot_table 명령을 이용해 데이터 형태를 바꾸어야 합니다.

  • 가로에 X 데이터를, 세로에 Y 데이터를 놓기 위해 X를 column, Y를 index로 지정합니다.

    1
    data1.pivot_table("Z", "Y1", "X")


  • numerical data인 X와 Y를 categorical처럼 사용합니다.

  • 데이터가 직교하지 않는 경우 매우 많은 구멍이 뚫립니다.
    (ex. data1에서 X와 Y2로 pivot_table()을 하는 경우)

  • X, Y, Z 컬럼이 따로 있는 방식을 long format,

  • 가로와 세로에 X와 Y가 놓이고 Z가 행렬처럼 놓이는 방식을 wide format이라고 합니다.

  • long format을 입력받는 명령이 있고, wide format을 입력받는 명령이 있습니다.

  • 데이터를 이미지인양 표현합니다.
  • 가로와 세로축이 index이기 때문에 extent매개변수로 실제 값과 매핑을 시켜야 합니다.
  • 이 작업이 없으면 값이 아니라 index로 출력됩니다.
  • 그리고 하나 더, 이미지는 다른 plot과 달리 y값이 위에서부터 아래로 작아집니다.
  • ax.invert_yaxis()를 해서 상하를 뒤집어야 합니다.
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    fig, axs = plt.subplots(ncols=2, figsize=(8, 4), constrained_layout=True)

    fontlabel = {"fontsize":"large", "color":"gray", "fontweight":"bold"}
    for ax, data in zip(axs, [data1, data2]):
    ydata = "Y1" if data is data1 else "Y2"
    ax.imshow(data.pivot_table("Z", ydata, "X"), cmap="inferno",
    extent=[data["X"].min(), data["X"].max(), # 가로에 X 데이터
    data[ydata].max(), data[ydata].min()]) # 세로에 Y 데이터
    ax.invert_yaxis() # 그림 뒤집기
    ax.set_aspect((data["X"].max() - data["X"].min())/(data[ydata].max() - data[ydata].min()))
    ax.set_xlabel("X", fontdict=fontlabel)
    ax.set_ylabel(ydata, fontdict=fontlabel)

2.3. ax.pcolor(), ax.pcolormesh(), ax.pcolorfast()

matplotlib: Difference between pcolor() and pcolormesh()
matplotlib.axes.Axes.pcolorfast

  • 이미지처럼 출력할 수 있는 명령은 여러 가지가 있습니다.
  • ax.pcolor(), ax.pcolormesh(), ax.pcolorfast()가 그들입니다.
  • ax.imshow()는 pixel의 가로세로 길이가 같아야 하지만 이들은 그렇지 않습니다.
  • ax.pcolor()PolyCollection을 return하고
  • ax.pcolormesh()QuadMesh를 return합니다.
  • ax.pcolormesh()가 빠르고 Gouraud shading을 지원하지만 ax.pcolor()는 masked array를 지원합니다.
  • ax.pcolorfast()는 Agg backend를 이용해 ax.pcolor()를 지원하는 실험적인 기능입니다.
  • 예제 코드는 ax.pcolor()만 사용합니다.

  • 실행해보면 아시겠지만 다른 명령도 결과가 똑같습니다.

    1
    2
    3
    4
    5
    6
    7
    8
    fig, axs = plt.subplots(ncols=2, figsize=(8, 4), constrained_layout=True)

    fontlabel = {"fontsize":"large", "color":"gray", "fontweight":"bold"}
    for ax, data in zip(axs, [data1, data2]):
    ydata = "Y1" if data is data1 else "Y2"
    ax.pcolor(data.pivot_table("Z", ydata, "X"), cmap="inferno")
    ax.set_xlabel("X (index)", fontdict=fontlabel)
    ax.set_ylabel(ydata + " (index)", fontdict=fontlabel)

  • X, Y 좌표가 실제 데이터가 아니라 index입니다.

  • index를 좌표로 바꿉니다.

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    fig, axs = plt.subplots(ncols=2, figsize=(8, 4), constrained_layout=True)

    fontlabel = {"fontsize":"large", "color":"gray", "fontweight":"bold"}
    for ax, data in zip(axs, [data1, data2]):
    ydata = "Y1" if data is data1 else "Y2"
    data_pt = data.pivot_table("Z", ydata, "X")
    ax.pcolor(data_pt, cmap="inferno")
    ax.set_xlabel("X", fontdict=fontlabel)
    ax.set_ylabel(ydata, fontdict=fontlabel)

    # index 중에서 그림에 표시된 범위와 전체 데이터 범위를 넘지 않는 것들 선택
    xticks = [x for x in ax.get_xticks()
    if (ax.get_xbound()[0] <= x <= ax.get_xbound()[1]) # 그림 표시 범위 안쪽
    and x < len(data_pt.columns)] # 전체 데이터 범위 안쪽
    ax.set_xticks(xticks)
    ax.set_xticklabels([f"{data_pt.columns[int(x)]:.3f}" for x in xticks]) # 문자열 포맷 지정

    yticks = [y for y in ax.get_yticks()
    if (ax.get_ybound()[0] <= y <= ax.get_ybound()[1])
    and y < len(data_pt.index)]
    ax.set_yticks(yticks)
    ax.set_yticklabels([f"{data_pt.index[int(y)]:.3f}" for y in yticks])


  • 이런 식으로 해당 index의 좌표를 억지로 찾아 넣을 수는 있으나 자연스럽지 못합니다.

  • pcolor 시리즈는 다른 용처를 찾아봅시다. 여긴 아닙니다.

2.4. ax.contour

matplotlib: Contour Demo
matplotlib: matplotlib.contour

  • Z 데이터를 등고선으로 표기합니다.

  • 기상도나 지형도에서 많이 본 형식이기 때문에 데이터 전달력이 좋습니다.

  • ax.imshow처럼 extent 매개변수를 사용할 수 있어 실제 좌표로 표현이 가능합니다.

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    fig, axs = plt.subplots(ncols=2, figsize=(8, 4), constrained_layout=True)

    fontlabel = {"fontsize":"large", "color":"gray", "fontweight":"bold"}
    for ax, data in zip(axs, [data1, data2]):
    ydata = "Y1" if data is data1 else "Y2"
    ax.contour(data.pivot_table("Z", ydata, "X"), # wide format 데이터 활용
    levels=25, cmap="inferno", # 등고선 갯수와 colormap 지정
    # extent 매개변수 적용 : x축과 y축에 실제 데이터값 매핑
    extent=[data["X"].min(), data["X"].max(), data[ydata].min(), data[ydata].max()])
    ax.set_xlabel("X", fontdict=fontlabel)
    ax.set_ylabel(ydata, fontdict=fontlabel)


  • 등고선에 데이터 값을 표기해서 정확한 값을 전달하는 것이 가능합니다.

  • 데이터 값이 들어가려면 등고선의 수를 좀 줄이는 편이 좋습니다.

  • 글자가 놓이는 부분의 등고선을 생략하려면 매개변수로 inline=True를 넣고,
  • 글자 좌우 여백을 조정하고 싶으면 inline_spacing을 조절합니다.
  • 매개변수로 상당히 많은 세부사항을 조정할 수 있습니다.
  • 자세한 내용은 공식 홈페이지matplotlib.contour.ContourLabeler.clabel을 참고합니다.
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    fig, axs = plt.subplots(ncols=2, figsize=(8, 4), constrained_layout=True)

    fontlabel = {"fontsize":"large", "color":"gray", "fontweight":"bold"}
    for ax, data in zip(axs, [data1, data2]):
    ydata = "Y1" if data is data1 else "Y2"
    cs = ax.contour(data.pivot_table("Z", ydata, "X"), # contour label 삽입을 위해 객체(cs)로 받음.
    levels=15, cmap="inferno",
    extent=[data["X"].min(), data["X"].max(), data[ydata].min(), data[ydata].max()])
    ax.clabel(cs, inline=True, inline_spacing=0, fontsize=12) # contour label 설정
    ax.set_xlabel("X", fontdict=fontlabel)
    ax.set_ylabel(ydata, fontdict=fontlabel)

2.5. ax.contourf()

  • 등고선 영역에 색칠을 합니다.
  • ax.imshow()를 단계별로 quantize 했다고도 볼 수 있습니다.
  • 자체적으로도 ax.clabel()로 등고선 레이블을 붙일 수 있지만 ax.contour()과 협력하면 더 좋습니다.
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    fig, axs = plt.subplots(ncols=2, figsize=(8, 4), constrained_layout=True)

    fontlabel = {"fontsize":"large", "color":"gray", "fontweight":"bold"}
    for ax, data in zip(axs, [data1, data2]):
    ydata = "Y1" if data is data1 else "Y2"
    ax.contourf(data.pivot_table("Z", ydata, "X"), # 등고선 사이 채색
    levels=25, cmap="inferno",
    extent=[data["X"].min(), data["X"].max(), data[ydata].min(), data[ydata].max()])
    cs = ax.contour(data.pivot_table("Z", ydata, "X"), # 등고선
    levels=25, colors="w", alpha=0.5,
    extent=[data["X"].min(), data["X"].max(), data[ydata].min(), data[ydata].max()])
    ax.clabel(cs, inline=True, inline_spacing=0, fontsize=12) # 등고선 레이블
    ax.set_xlabel("X", fontdict=fontlabel)
    ax.set_ylabel(ydata, fontdict=fontlabel)

2.6. ax.triplot()

matplotlib.axes.Axes.triplot

  • 입력받은 데이터를 삼각형으로 잘게 자릅니다.

  • Delaunay triangulation이라고 하는데, 유한요소해석(finite element analysis)을 배운 분들께는 매우 익숙할 단어입니다.

  • 지금 데이터가 너무 빼곡하기 때문에 일부만 샘플링을 해서 데이터 수를 줄입니다.

  • x의 최대와 최소, y의 최대와 최소는 남겨서 데이터 범위는 유지합니다.

    1
    2
    3
    4
    5
    6
    7
    8
    data1_ll = data1.sort_values(["X", "Y1"], ascending=[True, True]).index[0]
    data1_ul = data1.sort_values(["X", "Y1"], ascending=[True, False]).index[0]
    data1_lr = data1.sort_values(["X", "Y1"], ascending=[False, True]).index[0]
    data1_ur = data1.sort_values(["X", "Y1"], ascending=[False, False]).index[0]

    data1_s = pd.concat([data1.loc[[data1_ll, data1_ul, data1_lr, data1_ur]], data1.sample(n=400)], axis=0)
    data1_s.reset_index(inplace=True, drop=True)
    data1_s.head()


  • data2에도 같은 처리를 합니다.

  • data2는 영역이 좁은만큼 데이터 수를 data1의 2분의 1로 줄입니다.

    1
    2
    3
    4
    5
    6
    7
    8
    data2_ll = data2.sort_values(["X", "Y2"], ascending=[True, True]).index[0]
    data2_ul = data2.sort_values(["X", "Y2"], ascending=[True, False]).index[0]
    data2_lr = data2.sort_values(["X", "Y2"], ascending=[False, True]).index[0]
    data2_ur = data2.sort_values(["X", "Y2"], ascending=[False, False]).index[0]

    data2_s = pd.concat([data2.loc[[data2_ll, data2_ul, data2_lr, data2_ur]], data2.sample(n=200)], axis=0)
    data2_s.reset_index(inplace=True, drop=True)
    data2_s.head()


  • 줄어든 데이터로 triplot을 그립니다.

  • pivot_table을 할 필요가 없습니다.

    1
    2
    3
    4
    5
    6
    7
    8
    fig, axs = plt.subplots(ncols=2, figsize=(8, 4), constrained_layout=True)

    fontlabel = {"fontsize":"large", "color":"gray", "fontweight":"bold"}
    for ax, data in zip(axs, [data1_s, data2_s]):
    ydata = "Y1" if data is data1_s else "Y2"
    ax.triplot(data["X"], data[ydata], lw=0.1) # ax.triplot
    ax.set_xlabel("X", fontdict=fontlabel)
    ax.set_ylabel(ydata, fontdict=fontlabel)


  • 삼각형 mesh 구조가 생성되었습니다.

  • 다 좋은데, 오른쪽 그림을 보면 데이터 위쪽 활처럼 휜 곳이 직선으로 메워졌습니다.

  • all-convex 도형은 괜찮은데 안쪽으로 들어간(concave) 구조에 사용하기엔 적절치 않습니다.

2.7. ax.tricontour()

matplotlib.axes.Axes.tricontour

  • ax.contour()와 비슷한 기능입니다.
  • 차이가 있다면 삼각형으로 쪼갠 뒤에 그립니다.
  • 데이터의 정밀도를 높일 수 있는 방법입니다.
  • 데이터 구조는 보았으니 앞에서 성기게 만든 데이터보다 원 데이터로 작업합니다.
    1
    2
    3
    4
    5
    6
    7
    8
    fig, axs = plt.subplots(ncols=2, figsize=(8, 4), constrained_layout=True)

    fontlabel = {"fontsize":"large", "color":"gray", "fontweight":"bold"}
    for ax, data in zip(axs, [data1, data2]):
    ydata = "Y1" if data is data1 else "Y2"
    ax.tricontour(data["X"], data[ydata], data["Z"], levels=15, cmap="inferno")
    ax.set_xlabel("X", fontdict=fontlabel)
    ax.set_ylabel(ydata, fontdict=fontlabel)

2.8. ax.tricontourf()

matplotlib.axes.Axes.tricontourf

  • ax.contourf()에 대응되는 기능입니다.

    1
    2
    3
    4
    5
    6
    7
    8
    fig, axs = plt.subplots(ncols=2, figsize=(8, 4), constrained_layout=True)

    fontlabel = {"fontsize":"large", "color":"gray", "fontweight":"bold"}
    for ax, data in zip(axs, [data1, data2]):
    ydata = "Y1" if data is data1 else "Y2"
    ax.tricontourf(data["X"], data[ydata], data["Z"], levels=15, cmap="inferno")
    ax.set_xlabel("X", fontdict=fontlabel)
    ax.set_ylabel(ydata, fontdict=fontlabel)


  • Y2가 포함된 오른쪽 그림에서 자꾸 원하지 않는 concave 영역이 나타납니다.

  • mask 매개변수를 사용해서 이런 부분을 제거할 수 있을 것으로 보입니다.

  • 미래의 나에게 숙제로 맡기겠습니다.


도움이 되셨나요? 카페인을 투입하시면 다음 포스팅으로 변환됩니다

Share