Legend Control

  • Legend(범례)는 데이터의 의미 파악을 도와주는 도구입니다.
  • 그러나 그림이 여럿 있을 때 각각 붙은 Legend는 방해가 되기도 합니다.
  • Legend를 한데 모아 그리는 방법을 알아봅니다.

1. Sample Data

  • 먼저 필요한 라이브러리들을 불러오고,
1
2
3
4
5
6
7
8
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("whitegrid")
sns.set_context("talk")
  • 우리의 펭귄을 소환합니다.
1
2
df_p = sns.load_dataset("penguins")
df_p.head()


2. 기본 그림

  • legend를 붙일 그림을 먼저 그립니다.
  • seaborn의 regplot을 사용해서 부리 길이, 폭, 날개 길이를 그립니다.
  • scatter_kws와 line_kws로 시각화 요소들의 색상, 크기 등을 설정합니다.
1
2
3
4
5
6
7
8
9
10
11
12
13
fig, axs = plt.subplots(ncols=3, figsize=(10, 4), 
sharex=True, constrained_layout=True)

sns.regplot(x="body_mass_g", y="bill_length_mm", data=df_p,
ax=axs[0], label="Bill Length", scatter_kws={"s":10, "color":"C1"}, line_kws={"color":"k"})
sns.regplot(x="body_mass_g", y="bill_depth_mm", data=df_p,
ax=axs[1], label="Bill Depth", scatter_kws={"s":10, "color":"C2"}, line_kws={"color":"k"})
sns.regplot(x="body_mass_g", y="flipper_length_mm", data=df_p,
ax=axs[2], label="Flipper Length", scatter_kws={"s":10, "color":"C3"}, line_kws={"color":"k"})

for ax in axs:
ax.set_xlabel("Body Mass")
ax.set_ylabel("")


  • sns.regplot()함수가 데이터 수만큼 반복되고 있습니다.
  • 데이터 수가 10개라면 코드가 그만큼 더 길어질 것입니다.
  • for loop과 zip을 사용해서 효율적으로 바꿉니다. 조금 짧아지고 유지보수가 편해집니다.
1
2
3
4
5
6
7
8
9
fig, axs = plt.subplots(ncols=3, figsize=(10, 4), 
sharex=True, constrained_layout=True)

for ax, c, y in zip(axs, ["C1", "C2", "C3"], ["bill_length_mm", "bill_depth_mm", "flipper_length_mm"]):
label = " ".join([w[0].upper()+w[1:] for w in y.split("_")[:-1]])
sns.regplot(x="body_mass_g", y=y, data=df_p,
ax=ax, label=label, scatter_kws={"s":10, "color":c}, line_kws={"color":"k"})
ax.set_xlabel("Body Mass")
ax.set_ylabel("")


  • ax.set_ylabel("")로 ylabel을 지웠습니다.
  • y 인자 이름을 축 레이블 대신 legend 형태로 표현하기 위해서입니다.

3. Legend 하나씩

3.1. Axes별 Legend

  • 가장 기본적인 형태입니다.
  • Axes 하나마다 ax.legend()를 실행합니다.
  • markerscale은 데이터를 의미하는 마커를 3배 크게 그리라는 의미입니다.
  • 잘 보이게 하고 데이터와 혼동되지 않게 하려는 의도입니다.
1
2
3
4
5
6
7
8
9
10
11
fig, axs = plt.subplots(ncols=3, figsize=(10, 4), 
sharex=True, constrained_layout=True)

for ax, c, y in zip(axs, ["C1", "C2", "C3"], ["bill_length_mm", "bill_depth_mm", "flipper_length_mm"]):
label = " ".join([w[0].upper()+w[1:] for w in y.split("_")[:-1]])
sns.regplot(x="body_mass_g", y=y, data=df_p,
ax=ax, label=label, scatter_kws={"s":10, "color":c}, line_kws={"color":"k"})
ax.set_xlabel("Body Mass")
ax.set_ylabel("")

ax.legend(markerscale=3) # Axes별 Legend


3.2. Axes 공간 전체 사용

  • Axes마다 붙긴 했는데 깔끔하지 않습니다. 좀 지저분합니다.
  • Axes마다 Legend가 귀퉁이에 쭈그리고 있어서 그런가 싶습니다.
  • mode="extend"로 전체 공간을 다 사용하도록 합니다.
1
2
3
4
5
6
7
8
9
10
11
fig, axs = plt.subplots(ncols=3, figsize=(10, 4), 
sharex=True, constrained_layout=True)

for ax, c, y in zip(axs, ["C1", "C2", "C3"], ["bill_length_mm", "bill_depth_mm", "flipper_length_mm"]):
label = " ".join([w[0].upper()+w[1:] for w in y.split("_")[:-1]])
sns.regplot(x="body_mass_g", y=y, data=df_p,
ax=ax, label=label, scatter_kws={"s":10, "color":c}, line_kws={"color":"k"})
ax.set_xlabel("Body Mass")
ax.set_ylabel("")

ax.legend(mode="expand", markerscale=3) # Axes별 Legend, 넓게


3.3. Axes 위에 Legend

  • scatter plot은 점 하나하나가 데이터입니다.
  • 시각화 요소들에 의해 가려지면 그만큼 데이터 전달력이 손실됩니다.
  • Axes 위로 Legend를 올려서 데이터를 잘 보이게 합니다.
1
2
3
4
5
6
7
8
9
10
11
12
fig, axs = plt.subplots(ncols=3, figsize=(10, 4), 
sharex=True, constrained_layout=True)

for ax, c, y in zip(axs, ["C1", "C2", "C3"], ["bill_length_mm", "bill_depth_mm", "flipper_length_mm"]):
label = " ".join([w[0].upper()+w[1:] for w in y.split("_")[:-1]])
sns.regplot(x="body_mass_g", y=y, data=df_p,
ax=ax, label=label, scatter_kws={"s":10, "color":c}, line_kws={"color":"k"})
ax.set_xlabel("Body Mass")
ax.set_ylabel("")

# Axes별 Legend, 위치 지정
ax.legend(loc="lower center", bbox_to_anchor=[0.5, 1.03], markerscale=3)


  • 그런데 여기서 전체 범위를 사용하겠다고 mode="extend"를 사용하면 오류가 납니다.
1
ax.legend(loc="lower center", bbox_to_anchor=[0, 1.03], mode="expand", markerscale=3)


  • 매개변수에서 bbox_to_anchor를 제거하고 사용하면 잘 됩니다.
1
ax.legend(loc=[0, 1.03], mode="expand", borderaxespad=0, markerscale=3)


4. Legend 모아 붙이기

  • Axes별로 Legend를 출력하지 않고 한데 모으면 더 깔끔합니다.

  • 이를 가능하게 하려면 handle과 label이라는 개념을 파악할 필요가 있습니다.

  • legend는 의미가 담긴 label과 label이 지칭하는 대상이 있습니다. 이 대상이 handle입니다.

  • ax.get_legend_handles_labels 명령으로 확인하고 가져올 수 있습니다.

  • 위 그림의 첫번째 Axes에 담긴 handle과 label은 이렇습니다.

1
2
3
handle, label = axs[0].get_legend_handles_labels()
print(f"handle= {handle}")
print(f"label = {label}")
  • 실행 결과: 시각화를 하나밖에 안했으므로 handle과 label이 하나씩입니다.
1
2
handle= [<matplotlib.collections.PathCollection object at 0x7f95fa482e90>]
label = ['Bill Length']

4.1. Axes에 붙이기

  • 빈 list를 만들고, 그림을 그릴 때마다 handle과 label을 가져와 모읍니다.
  • 그림을 모두 다 그린 후, 맨 마지막 Axes 오른쪽에 붙입니다.
  • 추가 공간이 필요하니 그림 가로 폭을 10에서 14로 넓혀줍니다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
fig, axs = plt.subplots(ncols=3, figsize=(14, 4), 
sharex=True, constrained_layout=True)

handles = []
labels = []
for ax, c, y in zip(axs, ["C1", "C2", "C3"], ["bill_length_mm", "bill_depth_mm", "flipper_length_mm"]):
label = " ".join([w[0].upper()+w[1:] for w in y.split("_")[:-1]])
sns.regplot(x="body_mass_g", y=y, data=df_p,
ax=ax, label=label, scatter_kws={"s":10, "color":c}, line_kws={"color":"k"})
ax.set_xlabel("Body Mass")
ax.set_ylabel("")

# handles, labels 모으기
handle, label = ax.get_legend_handles_labels()
handles.append(*handle)
labels.append(*label)

# 맨 오른쪽 Axes에 붙이기
axs[-1].legend(handles=handles, labels=labels, markerscale=3,
loc="upper left", bbox_to_anchor=[1.1, 1])


4.2. Figure에 붙이기

  • 특정 Axes에 속하지 않도록 전체 그림이 담긴 Figure에 붙일 수 있습니다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
fig, axs = plt.subplots(ncols=3, figsize=(14, 4), 
sharex=True, constrained_layout=True)

handles = []
labels = []
for ax, c, y in zip(axs, ["C1", "C2", "C3"], ["bill_length_mm", "bill_depth_mm", "flipper_length_mm"]):
label = " ".join([w[0].upper()+w[1:] for w in y.split("_")[:-1]])
sns.regplot(x="body_mass_g", y=y, data=df_p,
ax=ax, label=label, scatter_kws={"s":10, "color":c}, line_kws={"color":"k"})
ax.set_xlabel("Body Mass")
ax.set_ylabel("")

# handles, labels 모으기
handle, label = ax.get_legend_handles_labels()
handles.append(*handle)
labels.append(*label)

# Figure에 붙이기
fig.legend(handles=handles, labels=labels, markerscale=3)


4.2. Figure의 Axes 옆자리에 붙이기

  • Axes를 그대로 놔두고 붙였더니 맨 우측 Axes에 겹쳐 그려졌습니다.

  • Axes 옆에 놓기 위해 legend의 위치를 섬세하게 지정합니다.

  • loc와 bbox_to_anchor 매개변수는 이런 역할을 합니다.

  • loc만 단독으로 사용하면 붙이는 대상에 따라 Figure나 Axes의 지정된 위치에 놓입니다

  • loc와 bbox_to_anchor를 함께 사용하면 loc는 legend의 지점이 되고 bbox_to_anchor는 legend가 놓일 위치가 됩니다.

  • bbox_to_ancher에 매개변수가 둘 들어가면 위치만, 넷 들어가면 위치와 가로세로 크기입니다.


  • 맨 우측 Axes 오른쪽 상단에 붙도록 값을 지정합니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
fig, axs = plt.subplots(ncols=3, figsize=(14, 4), 
sharex=True, constrained_layout=True)

handles = []
labels = []
for ax, c, y in zip(axs, ["C1", "C2", "C3"], ["bill_length_mm", "bill_depth_mm", "flipper_length_mm"]):
label = " ".join([w[0].upper()+w[1:] for w in y.split("_")[:-1]])
sns.regplot(x="body_mass_g", y=y, data=df_p,
ax=ax, label=label, scatter_kws={"s":10, "color":c}, line_kws={"color":"k"})
ax.set_xlabel("Body Mass")
ax.set_ylabel("")

# handles, labels 모으기
handle, label = ax.get_legend_handles_labels()
handles.append(*handle)
labels.append(*label)

# figure 옆에 legend 붙이기: 화면엔 정상, 파일은 실패
fig.legend(handles=handles, labels=labels, markerscale=3,
loc="upper left", bbox_to_anchor=[1, 1])


  • 화면에는 정상으로 나오지만 파일을 저장하면 그렇지 않습니다.


  • legend가 전혀 보이지 않습니다.

  • bbox_to_anchor에서 지정한 x 위치가 Figure의 우측 한계선(1)을 넘었기 때문입니다.

  • 파일 출력을 하려면 legend 전체가 Figure 범위 안에 들어와야 합니다.

  • 그러려면 Axes를 좌측으로 압축시킬 필요가 있습니다.

  • fig.tight_layout()에 rect 매개변수를 넣으면 됩니다.

  • 충돌 방지를 위해 비슷한 기능을 하는 constrained_layout은 figure 생성 명령에서 삭제합니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
fig, axs = plt.subplots(ncols=3, figsize=(14, 4), sharex=True)

handles = []
labels = []
for ax, c, y in zip(axs, ["C1", "C2", "C3"], ["bill_length_mm", "bill_depth_mm", "flipper_length_mm"]):
label = " ".join([w[0].upper()+w[1:] for w in y.split("_")[:-1]])
sns.regplot(x="body_mass_g", y=y, data=df_p,
ax=ax, label=label, scatter_kws={"s":10, "color":c}, line_kws={"color":"k"})
ax.set_xlabel("Body Mass")
ax.set_ylabel("")

# handles, labels 모으기
handle, label = ax.get_legend_handles_labels()
handles.append(*handle)
labels.append(*label)

# Figure 옆에 legend 붙이기. x 좌표 = 0.8
fig.legend(handles=handles, labels=labels, markerscale=3,
loc="upper left", bbox_to_anchor=[0.8, 0.95])
# Axes들 0.8 안쪽으로 압축
fig.tight_layout(rect=[0,0,0.8,1])


4.3. Figure의 Axes 위에 붙이기

  • 같은 요령으로 legend를 Axes 위에 모아서 붙일 수 있습니다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
fig, axs = plt.subplots(ncols=3, figsize=(14, 4), sharex=True)

handles = []
labels = []
for ax, c, y in zip(axs, ["C1", "C2", "C3"], ["bill_length_mm", "bill_depth_mm", "flipper_length_mm"]):
label = " ".join([w[0].upper()+w[1:] for w in y.split("_")[:-1]])
sns.regplot(x="body_mass_g", y=y, data=df_p,
ax=ax, label=label, scatter_kws={"s":10, "color":c}, line_kws={"color":"k"})
ax.set_xlabel("Body Mass")
ax.set_ylabel("")

# handles, labels 모으기
handle, label = ax.get_legend_handles_labels()
handles.append(*handle)
labels.append(*label)

# Figure 위에 legend 붙이기. y 좌표 = 0.9
fig.legend(handles=handles, labels=labels, markerscale=3, ncol=3,
loc="upper left", bbox_to_anchor=[0.045, 0.9, 1, 0.1])
# Axes들 y = 0.9 안쪽으로 압축
fig.tight_layout(rect=[0,0,1,0.9])


5. 정리

  • legend는 여러 데이터를 명확히 구분해주는, 반드시 필요한 요소입니다.
  • 그러나 Axes가 많아지고 데이터 인자가 많아질수록 혼돈의 원인이 되기도 합니다.
  • 적절한 위치에 적절한 형식으로 배치해서 인지능력 향상에 도움이 되면 좋겠습니다.


도움이 되셨나요? 카페인을 투입하시면 다음 포스팅으로 변환됩니다

Share