Notice
Recent Posts
Recent Comments
«   2024/12   »
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30 31
Archives
Today
Total
관리 메뉴

SYDev

Chapter 01-3: 마켓과 머신러닝 본문

KHUDA 4th/머신러닝 기초 세션

Chapter 01-3: 마켓과 머신러닝

시데브 2023. 7. 30. 16:44

생선 분류 문제

  • 생선의 특성인 '길이'와 '무게' 데이터를 가지고 생선 '도미'와 '빙어'를 분류하는 모델을 만들어보자.

 

용어 정리

  • 분류(classification): 머신러닝에서 여러 개의 종류(혹은 클래스(class)) 중 하나를 구별해내는 문제를 이르는 말.
  • 특성(feature): 학습 모델로 정답을 도출해내기 위해 고려하는 데이터
  • 샘플(sample): 학습 데이터에 포함된 하나의 특징벡터

 위 문제에서는 '도미'와 '빙어'가 각각 분류해야 하는 클래스이며, 생선의 길이와 무게가 특성에 해당된다. 마지막으로 생선의 길이 데이터에 포함된 각각의 값(예를 들어 25.4cm)이 샘플이다.

도미, 빙어 데이터

bream_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0,
                31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0,
                35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0]
bream_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0,
                500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0,
                700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0]
smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]
smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]

 matplotlib을 이용하여 길이를 x축, 무게를 y축으로 각 데이터를 그래프(이를 산점도(scatter plot)라 부른다)로 그리면 다음과 같다. 

import matplotlib.pyplot as plt #matplotlib의 pylot 함수를 plt로 줄여서 사용
plt.scatter(bream_length, bream_weight) #입력받은 인자들의 산점도 나타냄
plt.scatter(smelt_length, smelt_weight) #입력받은 인자들의 산점도 나타냄
plt.xlabel('length')  #x축 이름 표시
plt.ylabel('weight')  #y축 이름 표시
plt.show()  #그래프 화면에 출력

  • 산점도에서 그래프가 일직선에 가깝게 나타나는 것선형(linear)적이라 한다.
  • 그래프를 살펴봤을 때, 도미와 빙어의 그래프 모두 선형적이라 할 수 있다.

 

두 클래스의 분류하는 모델

  • k-최근접 이웃(k-Nearest Neighbors) 알고리즘을 사용해 도미와 빙어 데이터를 구분해보자.
  • k-최근접 이웃 알고리즘(k-NN 알고리즘): 주변의 가장 가까운 k개의 데이터를 보고 데이터가 속할 그룹을 판단하는 알고리즘

 

학습 데이터와 평가 데이터 준비

 우선 학습을 위해서 다음과 같이 length와 weight 리스트를 합쳐 학습 데이터를 만든다.

length = bream_length + smelt_length  #두 리스트가 가진 데이터를 모두 합친 하나의 리스트 length
weight = bream_weight + smelt_weight  #두 리스트가 가진 데이터를 모두 합친 하나의 리스트 weight

 다음으로 사이킷런(scikit-learn) 패키지를 사용하기 위해 각 특성의 리스트를 세로 방향으로 늘어뜨린 2차원 리스트로 만들어야 한다.

fish_data = [[l, w] for l, w in zip(length, weight)]  #length와 weight에서 원소를 하나씩 꺼내어 l과 w에 할당 -> [l, w]가 하나의 원소로 구성된 리스트

 

 fish_data를 출력한 결과는 다음과 같고, 이는 훈련 데이터이다.

print(fish_data)  #2차원 리스트 출력
[[25.4, 242.0], [26.3, 290.0], [26.5, 340.0], [29.0, 363.0], [29.0, 430.0], [29.7, 450.0],
[29.7, 500.0], [30.0, 390.0], [30.0, 450.0], [30.7, 500.0], [31.0, 475.0], [31.0, 500.0],
[31.5, 500.0], [32.0, 340.0], [32.0, 600.0], [32.0, 600.0], [33.0, 700.0], [33.0, 700.0],
[33.5, 610.0], [33.5, 650.0], [34.0, 575.0], [34.0, 685.0], [34.5, 620.0], [35.0, 680.0],
[35.0, 700.0], [35.0, 725.0], [35.0, 720.0], [36.0, 714.0], [36.0, 850.0], [37.0, 1000.0],
[38.5, 920.0], [38.5, 955.0], [39.5, 925.0], [41.0, 975.0], [41.0, 950.0], [9.8, 6.7],
[10.5, 7.5], [10.6, 7.0], [11.0, 9.7], [11.2, 9.8], [11.3, 8.7], [11.8, 10.0], [11.8, 9.9],
[12.0, 9.8], [12.2, 12.2], [12.4, 13.4], [13.0, 12.2], [14.3, 19.7], [15.0, 19.9]]

 이제 마지막으로 머신러닝이 학습할 때 필요로 하는 정답 데이터(1은 도미, 0은 빙어를 의미)를 준비해야 한다.

fish_target = [1] * 35 + [0] * 14
print(fish_target)
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

 

도미와 빙어 분류

  • KNeighborsClassifier: 사이킷런 패키지에서 k-최근접 이웃 알고리즘을 구현한 클래스
from sklearn.neighbors import KNeighborsClassifier

-> 사이킷런 패키지에서 k-최근접 이웃 알고리즘을 구현한 클래스인 KNeighborsClassifier를 임포트

 

kn = KNeighborsClassifier()

-> 클래스 KNeighborsClassifier의 객체 kn을 생성

 

kn.fit(fish_data, fish_target)

-> 사이킷런의 메서드 fit()을 이용하여 객체에 학습 데이터, 평가 데이터 전달을 통해 해당 객체를 학습

(이런 과정을 머신러닝에서는 훈련(training)이라 함)

 

kn.score(fish_data, fish_target)

-> 메서드 score()를 이용하여 객체 kn이 얼마나 잘 훈련되었는지 평가, score()는 0에서 1사이의 값을 반환

1.0

-> score()의 반환값으로 1이 반환됐으므로, 해당 모델은 100%의 정확도를 가짐을 알 수 있음

 

k-최근접 이웃 알고리즘

predict()

  • 클래스 KNeighborsClassifier의 메서드 predict()를 이용하여 샘플 데이터가 어느 클래스에 속하는지 판단할 수 있다.

 예를 들어 길이가 30이고 무게가 600인 생선이 있다면, 아래 결과를 통해 해당 생선은 도미임을 알 수 있다.

kn.predict([[30,600]])
array([1])

 

k-최근접 이웃 알고리즘의 단점

  • 알고리즘 특성상 전달받은 모든 데이터를 가지고 있어야 하므로, 데이터의 양이 매우 많은 경우에는 사용하기 어렵다.
  • 데이터의 양이 매우 많은 경우, 메모리가 많이 필요하고 직선거리를 계산하는 데도 많은 시간이 필요하다.

 

매개변수 n_neighbors

  • k-최근접 이웃 알고리즘이 새로운 데이터를 받았을 때, 참고하는 주변 데이터의 개수 디폴트는 5개이다.
  • 매개변수 n_neighbors를 통해서 참고하는 주변 데이터의 개수를 바꿀 수 있다.

예를 들어 매개변수 n_neighbors를 49로 설정하면 참고 데이터가 49개가 되고, fish_data에 있는 모든 데이터를 참고하게 된다. 이 경우에는 49개 중에 35개를 차지한 도미가 다수를 차지하므로, 어떤 데이터를 넣어도 해당 데이터를 도미로 예측할 것이다.

kn49 = KNeighborsClassifier(n_neighbors=49)
kn49.fit(fish_data, fish_target)
kn49.score(fish_data, fish_target)
0.7142857142857143
print(35/49)  # 49개 중에 35개를 맞춘 정확도
0.7142857142857143

 


참고자료

 

python : 머신러닝 기본 용어 정리

특징, 속성 (feature, attribute)특징이란 학습 모델로 정답을 도출하기 위해 고려할 데이터를 의미합니다. 의미 있는 특징이 많으면 그만큼 학습이 용이합니다. 일련의 특징을 특징 벡터라고도 합니

jjeongil.tistory.com

 

[머신러닝] K-최근접 이웃(K-NN) 알고리즘 및 실습

[목차] 1. K-NN 알고리즘이란? 2. K-NN 알고리즘 실습 3. K-NN 알고리즘 실습 (훈련 셋과 데이터 셋 분리) 4. K-NN 알고리즘의 주의점 1. K-NN 알고리즘이란? K-최근접 이웃(K-NN, K-Nearest Neighbor) 알고리즘은 가

rebro.kr

  • 박해선, <혼자 공부하는 머신러닝+딥러닝>, 한빛미디어(주), 2022.2.4