Matplotlib Vis. Function

  • 자주 사용하는 기능은 함수로 만들면 편리합니다.
  • 마찬가지로 자주 그리는 그림은 함수로 만들면 좋습니다.
  • Matplotlib 객체지향을 사용해 함수를 만듭시다.

1. Parity plot

  • 머신러닝 후 참값을 x축, 예측값을 y축에 놓고 얼마나 비슷한지 평가하고는 합니다.
  • 이런 그림을 parity plot이라고 하며, 매우 자주 그리는 그림입니다.
  • 그림이 목적이므로 데이터는 간단히 만듭니다.
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    %matplotlib inline

    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    import pandas as pd

    # 샘플 데이터 생성
    size = 1000
    x = np.random.normal(size=size, loc=12, scale=3)
    y = x + np.random.normal(size=size, loc=0, scale=2)
  • 그리고, parity plot은 이렇게 그려집니다.

  • x가 큰 지점에서는 예측값이 실제값보다 큽니다.

  • 중앙부 기준 중심선에서 5정도 어긋난 듯 합니다.


  • 많이 본 형태라 당연하게 여길 수 있겠지만 그냥 그리면 이렇습니다.
  • 대충 일치하는 것 같기는 합니다.
  • 그런데 얼마나 일치하고 얼마나 어긋나는지 잘 모르겠습니다.
    1
    plt.scatter(x, y, alpha=0.3)

  • x와 y축의 눈금을 일치시키고, grid와 중심선까지 그었기 때문에 보이는 것입니다.

  • 코드는 이렇습니다.

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    fig, ax = plt.subplots(figsize=(4, 4))

    # data plot
    ax.scatter(x, y, c="g", alpha=0.3)

    # x, y limits
    ax.set_xlim(0, 25)
    ax.set_ylim(0, 25)

    # x, y ticks, ticklabels
    ticks = [0, 5, 10, 15, 20, 25]
    ax.set_xticks(ticks)
    ax.set_xticklabels(ticks)
    ax.set_yticks(ticks)
    ax.set_yticklabels(ticks)

    # grid
    ax.grid(True)

    # 기준선
    ax.plot([0, 25], [0, 25], c="k", alpha=0.3)

    # x, y label
    font_label = {"color":"gray", "fontsize":"x-large"}
    ax.set_xlabel("true", fontdict=font_label, labelpad=8)
    ax.set_ylabel("predict", fontdict=font_label, labelpad=8)

    # title
    font_title = {"color": "gray", "fontsize":"x-large", "fontweight":"bold"}
    ax.set_title("true vs predict", fontdict=font_title, pad=16)

    # 파일로 저장
    fig.tight_layout()
    fig.savefig("73_mplfunc_01.png")
  • 34줄의 코드를 머신러닝 프로젝트때마다 짤 수는 있겠지만 귀찮습니다.

  • 함수로 만들어 자동화합시다.

2. 함수 만들기

  • python에서 함수는 def 함수이름(매개변수):로 선언함으로써 만들어집니다.
  • 그리고 함수 내부에 parity plot을 그리는 코드를 넣어주면 작동합니다.
  • 그렇다면, 함수의 결과물은 무엇으로 하는 게 좋을까요? 매개변수에는 뭘 넣을까요?
  • 간단한 예시를 만들며 고민해 봅시다

2.1. plot_sample()

matplotlib: Writing mathematical expressions

  • x, y 데이터를 입력받아 scatter plot을 그리는 함수입니다.
  • 1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    def plot_sample1(x, y, xlabel=None, ylabel=None, title=None, filename=None):
    # math fonts
    mathtext_fontset = plt.rcParams['mathtext.fontset']
    mathtext_default = plt.rcParams['mathtext.default']
    plt.rcParams['mathtext.fontset'] = "cm"
    plt.rcParams['mathtext.default'] = "it"

    # object oriented interface
    fig, ax = plt.subplots()

    # data plot
    ax.scatter(x, y)

    # x, y labels
    font_label = {"color":"gray", "fontsize":"x-large"}
    ax.set_xlabel(xlabel, fontdict=font_label, labelpad=8)
    ax.set_ylabel(ylabel, fontdict=font_label, labelpad=8)

    # title
    font_title = {"color": "darkgreen", "fontsize":"xx-large", "fontweight":"bold"}
    ax.set_title(title, fontdict=font_title, pad=16)

    # math fonts restoration
    plt.rcParams['mathtext.fontset'] = mathtext_fontset
    plt.rcParams['mathtext.default'] = mathtext_default

    # save figure
    fig.tight_layout()
    if filename:
    fig.savefig(filename)
  • 이 함수에 x와 y 데이터를 입력하면 다음과 같은 그림이 출력됩니다.

  • 함수 내부에서 fontdict를 사용해 xlabel과 ylabel, title 형태를 미리 설정했기 때문에,
    별다른 옵션을 지정하지 않았는데도 크기와 색상이 반영되어 있습니다.

  • 심지어 LaTeX 입력시 폰트도 roman(정확히는 Computer Modern)으로 설정됩니다.

    1
    2
    3
    4
    x_sample = np.linspace(0, 10, 100)
    y_sample = np.sin(x_sample)

    plot_sample1(x_sample, y_sample, "$X$", "$Y$", "$Y = \mathrm{sin}(X)$")


  • filename=에 적절한 이름을 담아 매개변수로 넣으면 파일 저장까지 자동으로 됩니다.

  • 기본값이 지정되면 사용이 편리합니다.
  • xlabel, ylabel, title, filename의 기본값 = None으로 지정했기 때문에, 함수의 인자로 x와 y만 입력해도 결과물이 출력됩니다.

2.2. 시각화 유형 변환

  • scatter plot 말고 다른 것도 그려봅시다.
  • 경우에 따라 line plot을 그리고싶다면 매개변수에 종류가 있으면 됩니다.
  • seaborn과 pandas를 따라 이름은 kind로 지정합니다.
  • 생각을 한번만 더 해봅시다.

    • matplotlib에 익숙한 이이라면 ax.scatter()가 익숙할 것입니다.
    • seaborn을 많이 쓰는 사람이라면 sns.scatterplot()이 친숙할 겁니다.
    • kind=로 전달되는 인자에 scatter가 있기만 하면 scatter plot을 그립시다.
    • if "scatter" in kind:로 구현할 수 있습니다.
    • line plot도 비슷하게 구현합니다.
  • 이제부터는 코드가 조금 길어집니다. 시각화 함수 코드는 기본적으로 숨겨두겠습니다.

  • 여기를 클릭하면 보입니다.

    plot_sample2() 코드 보기/접기
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    def plot_sample2(x, y, kind="scatter", xlabel=None, ylabel=None, title=None, filename=None):
    # math fonts
    mathtext_fontset = plt.rcParams['mathtext.fontset']
    mathtext_default = plt.rcParams['mathtext.default']
    plt.rcParams['mathtext.fontset'] = "cm"
    plt.rcParams['mathtext.default'] = "it"

    # object oriented interface
    fig, ax = plt.subplots()

    # data plot
    if "scatter" in kind:
    ax.scatter(x, y)
    elif "line" in kind:
    ax.plot(x, y)

    # x, y labels
    font_label = {"color":"gray", "fontsize":"x-large"}
    ax.set_xlabel(xlabel, fontdict=font_label, labelpad=8)
    ax.set_ylabel(ylabel, fontdict=font_label, labelpad=8)

    # title
    font_title = {"color": "darkgreen", "fontsize":"xx-large", "fontweight":"bold"}
    ax.set_title(title, fontdict=font_title, pad=16)

    # math fonts restoration
    plt.rcParams['mathtext.fontset'] = mathtext_fontset
    plt.rcParams['mathtext.default'] = mathtext_default

    # save figure
    fig.tight_layout()
    if filename:
    fig.savefig(filename)
  • line plot

    1
    plot_sample2(x_sample, y_sample, kind="line")


  • scatter plot (지정)

    1
    plot_sample2(x_sample, y_sample, kind="scatterplot")


  • scatter plot (기본값)

    1
    plot_sample2(x_sample, y_sample)


2.3. 유형에 따른 매개변수 입력

matplotlib.axes.Axes.plot
matplotlib.axes.Axes.scatter

  • Matplotlib의 plot()명령과 scatter()명령은 입력받는 인자가 다릅니다.
  • 이 인자들을 모두 매개변수로 넣자면 너무 많고 코드 관리가 어렵습니다.
  • plot()scatter()에 필요한 매개변수를 각기 line_kwsscatter_kws라는 이름의 dictionary 형식으로 입력하게 합시다.
  • dictionary 형식의 인자는 기본값을 None으로 넣고, 실제 시각화 코드에 **line_kws 형식으로 unpacking하여 입력합니다.

  • keyword arguments로 None이 들어가면 에러가 나기 때문에 간단한 예외처리를 합니다.

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    # 함수 선언 부분
    def plot_sample3(x, y, kind="scatter", xlabel=None, ylabel=None, title=None,
    filename=None, line_kws=None, scatter_kws=None):

    #... 전략 ...#

    # data plot
    if "scatter" in kind:
    if not scatter_kws:
    scatter_kws={}
    ax.scatter(x, y, **scatter_kws)
    elif "line" in kind:
    if not line_kws:
    line_kws={}
    ax.plot(x, y, **line_kws)

    #... 후략 ...#
  • keyword parameter를 적용한 코드입니다.

    plot_sample3() 코드 보기/접기
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    def plot_sample3(x, y, kind="scatter", xlabel=None, ylabel=None, title=None, 
    filename=None, line_kws=None, scatter_kws=None):
    # math fonts
    mathtext_fontset = plt.rcParams['mathtext.fontset']
    mathtext_default = plt.rcParams['mathtext.default']
    plt.rcParams['mathtext.fontset'] = "cm"
    plt.rcParams['mathtext.default'] = "it"

    # object oriented interface
    fig, ax = plt.subplots()

    # data plot
    if "scatter" in kind:
    if not scatter_kws:
    scatter_kws={}
    ax.scatter(x, y, **scatter_kws)
    elif "line" in kind:
    if not line_kws:
    line_kws={}
    ax.plot(x, y, **line_kws)

    # x, y labels
    font_label = {"color":"gray", "fontsize":"x-large"}
    ax.set_xlabel(xlabel, fontdict=font_label, labelpad=8)
    ax.set_ylabel(ylabel, fontdict=font_label, labelpad=8)

    # title
    font_title = {"color": "darkgreen", "fontsize":"xx-large", "fontweight":"bold"}
    ax.set_title(title, fontdict=font_title, pad=16)

    # math fonts restoration
    plt.rcParams['mathtext.fontset'] = mathtext_fontset
    plt.rcParams['mathtext.default'] = mathtext_default

    # save figure
    fig.tight_layout()
    if filename:
    fig.savefig(filename)
  • line_kws 적용
    1
    2
    3
    4
    5
    plot_sample3(x_sample, y_sample, kind="line",
    line_kws={"c":"r", "ls":":", "lw":3})
    # "c": "r" - line color = "red"
    # "ls": ":" - line style = ......
    # "lw": 3 - line width = 3

  • scatter_kws 적용 : line_kws는 무시됩니다.
    1
    2
    3
    4
    5
    6
    plot_sample3(x_sample, y_sample, kind="scatter",
    line_kws={"c": "r", "ls": ":", "lw": 3},
    scatter_kws={"s": 50, "ec": "b", "alpha": 0.2})
    # "s": 50 - marker size = 50
    # "ec": "b" - marker color = "blue"
    # "alpha": 0.2 - marker 불투명도 = 0.2

  • 이제 웬만한 함수는 원하는대로 만들 수 있습니다.
  • 그런데, 한번 만들고 끝일까요?
  • 함수를 실행할 때는 title을 달지 않았는데, 나중에 달고 싶지 않을까요?
  • 그럴 때 return이 유용합니다.

2.4. Axes as return

  • matplotlib의 구성요소인 axes를 return 시키면 많은 것이 가능합니다.
  • 먼저, 기존의 코드에 return ax만 추가하고 실행해 봅니다.
    plot_sample4() 코드 보기/접기
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    def plot_sample4(x, y, kind="scatter", xlabel=None, ylabel=None, title=None, 
    filename=None, line_kws=None, scatter_kws=None):
    # math fonts
    mathtext_fontset = plt.rcParams['mathtext.fontset']
    mathtext_default = plt.rcParams['mathtext.default']
    plt.rcParams['mathtext.fontset'] = "cm"
    plt.rcParams['mathtext.default'] = "it"

    # object oriented interface
    fig, ax = plt.subplots()

    # data plot
    if "scatter" in kind:
    if not scatter_kws:
    scatter_kws={}
    ax.scatter(x, y, **scatter_kws)
    elif "line" in kind:
    if not line_kws:
    line_kws={}
    ax.plot(x, y, **line_kws)

    # x, y labels
    font_label = {"color":"gray", "fontsize":"x-large"}
    ax.set_xlabel(xlabel, fontdict=font_label, labelpad=8)
    ax.set_ylabel(ylabel, fontdict=font_label, labelpad=8)

    # title
    font_title = {"color": "darkgreen", "fontsize":"xx-large", "fontweight":"bold"}
    ax.set_title(title, fontdict=font_title, pad=16)

    # math fonts restoration
    plt.rcParams['mathtext.fontset'] = mathtext_fontset
    plt.rcParams['mathtext.default'] = mathtext_default

    # save figure
    fig.tight_layout()
    if filename:
    fig.savefig(filename)

    # return
    return ax
  • line plot을 그립니다.
  • return된 axes에는 방금 그린 그림의 정보가 모두 포함되어 있습니다.
    1
    ax = plot_sample4(x_sample, y_sample, kind="line", line_kws={"lw":3})

  • axes에 xlabel, ylabel, title을 추가할 수 있습니다.
  • plot 추가도 가능합니다. 코드도 일반적인 시각화와 동일합니다.
  • 심지어 순차적으로 적용되는 line color도 그냥 그리는 그림과 같습니다.
  • 당연합니다. 객체지향 방식이니까요 :)
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    # xlabel, ylabel, title 추가
    ax.set_xlabel("xlabel (postprocess)")
    ax.set_ylabel("ylabel (postprocess)")
    ax.set_title("title (postprocess)")

    # plot 추가
    ax.plot(x_sample+1, y_sample)

    # jupyter cell에서 시각화
    display(ax.figure)

※ 주의 ※

  • 그러나 자세히 보면 x, y label에는 설정된 값들이 제대로 적용되지만 title에는 색상만 적용됩니다.
  • title=None일 때 matplotlib axes가 담는 정보가 불충분한 것으로 생각됩니다.

2.5. Axes as input

  • axes는 함수의 입력 매개변수로 작용할 때 그 진가를 발합니다.
    • 함수복잡한 명령을 한번에 실행한다는 장점이 있지만 유연성이 부족합니다.
    • 날코딩유연성이 풍족하지만 일일이 코딩하기 번잡합니다.
    • 이 둘을 섞을 수 있는 방법이 axes를 input으로 받는 것입니다.
  • 머신러닝 예측결과 시각화로 예를 들어보겠습니다.

    • parity plot은 실제값과 예측값을 비교하는 그림입니다.
    • trainset, validation set, testset 세 데이터 모두에 대해 그릴 수 있습니다.
    • 이 중 하나만 그릴 때도 있고 둘만, 셋 다 그릴 때가 있습니다.
    • 이 때마다 함수를 일일이 만든다면 몹시 번거로울 것입니다.
  • 이럴 때 이런 해법을 만들 수 있습니다.

    1. plt.subplots()등으로 필요한 수만큼 Axes을 만듭니다.
    2. 준비된 함수로 각각의 Axes에 parity plot을 그립니다
  • 그러자면, 함수로 그려질 그림이 어디에 그려질지 지정되어야 합니다.
  • 매개변수로 axes를 받으면 가능합니다.
  • axes가 지정되지 않으면 스스로 figure를 만들도록 합니다. 이 때 figure size도 인자로 넣읍시다.
  • 파일로 저장하려면 figure 객체가 필요합니다. figure 객체는 axes 입력이 없을 때만 존재하니 예외처리를 합니다.
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    def plot_sample5(x, y, kind="scatter", xlabel=None, ylabel=None, title=None, 
    filename=None, line_kws=None, scatter_kws=None,
    figsize=plt.rcParams['figure.figsize'], ax=None):

    #... 전략 ...#

    # object oriented interface
    if not ax:
    fig, ax = plt.subplots(figsize=figsize)

    #... 후략 ...#

    # save figure
    if not ax:
    fig.tight_layout()
    if filename:
    fig.savefig(filename)

    # return
    return ax
    plot_sample5() 코드 보기/접기
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    def plot_sample5(x, y, kind="scatter", xlabel=None, ylabel=None, title=None, 
    filename=None, line_kws=None, scatter_kws=None,
    figsize=plt.rcParams['figure.figsize'], ax=None):
    # math fonts
    mathtext_fontset = plt.rcParams['mathtext.fontset']
    mathtext_default = plt.rcParams['mathtext.default']
    plt.rcParams['mathtext.fontset'] = "cm"
    plt.rcParams['mathtext.default'] = "it"

    # object oriented interface
    if not ax:
    fig, ax = plt.subplots(figsize=figsize)

    # data plot
    if "scatter" in kind:
    if not scatter_kws:
    scatter_kws={}
    ax.scatter(x, y, **scatter_kws)
    elif "line" in kind:
    if not line_kws:
    line_kws={}
    ax.plot(x, y, **line_kws)

    # x, y labels
    font_label = {"color":"gray", "fontsize":"x-large"}
    ax.set_xlabel(xlabel, fontdict=font_label, labelpad=8)
    ax.set_ylabel(ylabel, fontdict=font_label, labelpad=8)

    # title
    font_title = {"color": "darkgreen", "fontsize":"xx-large", "fontweight":"bold"}
    ax.set_title(title, fontdict=font_title, pad=16)

    # math fonts restoration
    plt.rcParams['mathtext.fontset'] = mathtext_fontset
    plt.rcParams['mathtext.default'] = mathtext_default

    # save figure
    if not ax:
    fig.tight_layout()
    if filename:
    fig.savefig(filename)

    # return
    return ax
  • 그냥 그리기
    1
    plot_sample5(x_sample, y_sample)
  • subplots 안에 넣기
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    fig, axs = plt.subplots(ncols=2, figsize=(8, 4), sharey=True)

    font_title = {"fontweight":"bold", "fontsize":"x-large"}

    # 왼쪽 axes
    plot_sample5(x_sample, y_sample, ax=axs[0])
    plot_sample5(10-x_sample, y_sample, ax=axs[0])
    axs[0].set_title("plot 1", fontdict=font_title)

    # 오른쪽 axes
    plot_sample5(x_sample, -y_sample, kind="line", line_kws={"c": "r"}, ax=axs[1])
    plot_sample5(x_sample, 1-y_sample, kind="line", line_kws={"c": "g"}, ax=axs[1])
    axs[1].set_title("plot 1", fontdict=font_title)

    # 그림 저장
    fig.tight_layout()
    fig.savefig("73_mplfunc_09.png")
  • 새로 그린 그림이 plt.subplots()로 만든 axes에 정확히 담겼습니다.
  • 그 뿐 아니라 sharey=True 와 같은 axes간 제약조건도 적용됩니다.
  • 함수의 문법이 어디선가 본 것 같다고 생각하셨으면 맞게 본 것입니다.
  • seaborn의 함수들이 바로 이렇게 만들어졌고 작동합니다.

3. parity plot

  • 다시 parity plot으로 돌아갑니다.
  • 아래는 제가 만든 함수로 그린 parity plot들입니다.
  • 다양한 경우에 활용할 수 있음을 알 수 있습니다.

    case I

case II

  • 코드입니다. 세 가지 지표를 평가하여 그림에 함께 담습니다.
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    import seaborn as sns
    import matplotlib.colors as colors
    from sklearn.metrics import mean_absolute_error
    from sklearn.metrics import mean_squared_error
    from sklearn.metrics import r2_score

    def get_metrics(true, predict):
    mae = mean_absolute_error(true, predict)
    rmse = mean_squared_error(true, predict, squared=False)
    r2 = r2_score(true, predict)

    return mae, rmse, r2

    def plot_parity(true, pred, kind="scatter",
    xlabel="true", ylabel="predict", title="true vs predict",
    hist2d_kws=None, scatter_kws=None, kde_kws=None,
    equal=True, metrics=True, metrics_position="lower right",
    figsize=(4, 4), ax=None, filename=None):

    if not ax:
    fig, ax = plt.subplots(figsize=figsize)

    # data range
    val_min = min(true.min(), pred.min())
    val_max = max(true.max(), pred.max())

    # data plot
    if "scatter" in kind:
    if not scatter_kws:
    scatter_kws={'color':'green', 'alpha':0.5}
    ax.scatter(true, pred, **scatter_kws)
    elif "hist2d" in kind:
    if not hist2d_kws:
    hist2d_kws={'cmap':'Greens', 'vmin':1}
    ax.hist2d(true, pred, **hist2d_kws)
    elif "kde" in kind:
    if not kde_kws:
    kde_kws={'cmap':'viridis', 'levels':5}
    sns.kdeplot(x=true, y=pred, **kde_kws, ax=ax)

    # x, y bounds
    xbounds = ax.get_xbound()
    ybounds = ax.get_ybound()
    max_bounds = [min(xbounds[0], ybounds[0]), max(xbounds[1], ybounds[1])]
    ax.set_xlim(max_bounds)
    ax.set_ylim(max_bounds)

    # x, y ticks, ticklabels
    ticks = [int(y) for y in ax.get_yticks() if (10*y)%10 == 0]
    ax.set_xticks(ticks)
    ax.set_xticklabels(ticks)
    ax.set_yticks(ticks)
    ax.set_yticklabels(ticks)

    # grid
    ax.grid(True)

    # 기준선
    ax.plot(max_bounds, max_bounds, c="k", alpha=0.3)

    # x, y label
    font_label = {"color":"gray", "fontsize":"x-large"}
    ax.set_xlabel(xlabel, fontdict=font_label, labelpad=8)
    ax.set_ylabel(ylabel, fontdict=font_label, labelpad=8)

    # title
    font_title = {"color": "gray", "fontsize":"x-large", "fontweight":"bold"}
    ax.set_title(title, fontdict=font_title, pad=16)

    # metrics
    if metrics:
    rmse = mean_squared_error(true, pred, squared=False)
    mae = mean_absolute_error(true, pred)
    r2 = r2_score(true, pred)

    font_metrics = {'color':'k', 'fontsize':12}

    if metrics_position == "lower right":
    text_pos_x = 0.98
    text_pos_y = 0.3
    ha = "right"
    elif metrics_position == "upper left":
    text_pos_x = 0.1
    text_pos_y = 0.9
    ha = "left"
    else:
    text_pos_x, text_pos_y = text_position
    ha = "left"

    ax.text(text_pos_x, text_pos_y, f"RMSE = {rmse:.3f}",
    transform=ax.transAxes, fontdict=font_metrics, ha=ha)
    ax.text(text_pos_x, text_pos_y-0.1, f"MAE = {mae:.3f}",
    transform=ax.transAxes, fontdict=font_metrics, ha=ha)
    ax.text(text_pos_x, text_pos_y-0.2, f"R2 = {r2:.3f}",
    transform=ax.transAxes, fontdict=font_metrics, ha=ha)

    # 파일로 저장
    if not ax:
    fig.tight_layout()
    if filename:
    fig.savefig(filename)

    return ax


도움이 되셨나요?

카페인을 투입하시면 다음 포스팅으로 변환됩니다

Share