본문 바로가기
Book Review/[혼공머신] 리뷰

k-최근접 이웃을 통한 분류 모델 훈련

by 3n952 2022. 11. 16.

 

혼공머신에 첫번째 챕터에서는 k-최근접 이웃(knn)을 통해 생선 이름을 자동으로 알려주는 머신러닝을 만듭니다.


 

생선 데이터셋은 캐글에 공개된 데이터 셋입니다.

http://www.kaggle.com/aungpyaeap/fish-market

 

Fish market

Database of common fish species for fish market

www.kaggle.com

1.  생선 분류 문제

 

fish market 데이터에는 다양한 생선 데이터가 들어있습니다. 이 데이터를 통해 생선을 분류하는 게 우리의 목적입니다.

 

그렇다면 어떠한 기준으로 생선을 나눠야 잘 분류했다고 할 수 있을까요?

생선의 크기와 길이 혹은 무게와 같은 절대적인 기준으로 분류를 하는 것은 부적절합니다.

ex) "50cm가 넘으면 도미야" , "1kg보다 가벼우면 도미가 아니야"

 

같은 종류의 생선의 크기가 같을 리가 없으며 종류가 다르더라도 크기나 길이에 큰 차이가 없다면 더더욱 그렇다.

--> 이러한 문제를 머신러닝을 통해 해결한다! :  분류 기준을 찾고 생선 종류를 판별한다.

 

2. k-최근접 이웃 알고리즘 (knn)

 

분류를 위한 다양한 알고리즘이 있지만 여기서는 k-최근접 이웃 알고리즘을 사용하여 분류합니다.

k-최근접 이웃 알고리즘은 어떠한 데이터에 대한 답을 구할 때 주위의 다른 데이터를 보고 다수를 차지하는 것(다수결의 원칙)을 예측값으로 사용합니다.

 

그림1) k-최근접 이웃 알고리즘

그림1에서 ?의 데이터가 새로 들어왔다고 하자. 

?의 데이터는 참고할 이웃의 갯수에 따라 별데이터가 될지,

삼각형 데이터가 될 지 정해진다.

k=3일 때는 ?데이터에 가장 가까운 값으로 삼각형 2개, 별 1개의 데이터가 있다. 

( 2/3(삼각형) > 1/3(별) ) 

따라서, k=3일때 ?의 데이터는 삼각형 데이터라고 분류하게 되는 것이다.

(여기서 가장 가까운 거리를 정할 때 적용하는 다양한 식이 있지만 그건 차후에 알아보고 여기서는 시각적으로 봤을 때(원으로 표현)

가장 가까운 k개의 이웃을 정했다.)

k=7일 때의 ?데이터는 별 데이터로 분류할 것이다.

( 4/7(별) > 3/7(삼각형) )

 

 

3. 도미 데이터 & 빙어 데이터로 이진 분류하기

 

이번 장에서는 fish market의 데이터 셋에서 도미와 빙어 데이터 일부를 가져와서 이진 분류를 진행합니다.

 

# bream 도미 데이터(길이와 무게): 35개

bream_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0, 
                31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0, 
                35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0]
bream_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0, 
                500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0, 
                700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0]

 

# smelt 빙어 데이터(길이와 무게): 14개

smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]
smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]

 

#도미와 빙어 데이터 산점도

plt.scatter(bream_length, bream_weight)
plt.scatter(smelt_length, smelt_weight)
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

주황색 점이 빙어, 파란색 점이 도미의 산점도

빙어와 도미의 산점도는 선형적이다. feature값인 length와 weight처럼 특성 간의 관계를 통해 분류를 할 수 있다.(2장에서 다룬다.)

 

 

4. 머신러닝 모델 구현

 

도미와 빙어의 길이, 무게의 리스트를 합쳐 하나의 데이터로 만들었습니다. 코드는 다음과 같습니다.

length = bream_length+smelt_length
weight = bream_weight+smelt_weight

 

머신러닝 패키지인 사이킷런(scikit-learn)을 사용하기 위해 각 특성의 리스트를 세로 방향으로 늘어뜨려 2차원 리스트로 만듭니다.

여기서는 zip()함수와 리스트 내포(list comprehension)구문을 사용하여 2차원 리스트로 만듭니다. 코드는 다음과 같습니다.

fish_data = [[l,w] for l,w in zip(length, weight)]

zip()함수는 나열된 리스트에서 원소를 하나씩 꺼내주는 역할을 합니다. length와 weight 리스트에 저장된 각각의 원소에 인덱스 별로 접근하여 l, w로 받아 2차원 리스트를 만듭니다.  (2차원리스트 형태 : [[a,b,c],[a,b,w],[f,g,x],[r,x,g], ....,[g,x,r]] )

 

fish_data를 보면 생선 49마리에 대한 길이, 무게 데이터가 만들어졌습니다. 

그 후에는 정답 데이터가 필요합니다.

각 길이, 무게 데이터와 정답 데이터가 있어야 분류를 할 수 있습니다. (supervised learning, 지도학습)

정답 레이블로 도미는 '1', 빙어는 '0'으로 표현하겠습니다.

 

fish_data에는 도미와 빙어 리스트를 단순히 +해서 합쳤기 때문에

정답리스트는 1~35번째 생선까지는 '1'이고, 36~49번째 생선까지는 '0'으로 구성됩니다.

코드는 다음과 같습니다.

fish_target = [1]*35 + [0]*14

 

학습 데이터와 정답 데이터가 만들어졌습다. 이제 사이킷런 패키지에서 knn 알고리즘을 구현하여 분류하는 일만 남았습니다.

우선 knn알고리즘을 구현한 클래스인 KNeighborsClassifier를 임포트합니다. 그 후 KNeighborsClassifier 클래스의 객체를 만듭니다.

코드는 다음과 같습니다.

from sklearn.neighbors import KNeighborsClassifier
kn = KNeighborsClassifier()

kn 객체를 통해 fish_data와 fish_target를 전달시켜 도미를 찾는 기준을 학습시킵니다. (=훈련시킨다. training)

 

객체를 만든 후에는 fit메소드를 통해 주어진 데이터로 학습을 한다. fit(feature data, target data)의 형태이다.

코드는 다음과 같다.

kn.fit(fish_data, fish_target)

 

훈련시킨 머신러닝 모델(여기서는 knn 알고리즘)이 잘 훈련이 되었는지 확인하는 것은 score()메소드를 사용하면 됩니다.

score()메소드는 0~1 사이의 값을 반환하는데 0에 가까울 수록 못 맞추고, 1에 가까울 수록 잘 맞춘다는 뜻입니다.

(target과 Knn을 통한 예측값을 비교함으로써 score메소드에 적용된다)

코드는 다음과 같습니다.

kn.score(fish_data, fish_target)

코드를 실행해보면 결과로 1.0이 나오는데 이는 모든 fish_data에 대한 예측값이 target과 모두 맞았다는 것을 의미합니다.

이 값을 정확도(accuracy)라고 합니다.

 

5.새로운 데이터에 대한 예측

모델이 잘 만들어졌으면 새로운 데이터를 추가하여 도미와 빙어 중 어디에 속하는지 맞추는지 보고싶을 것입니다. 

이게 우리의 궁극적인 목표이기 때문입니다.

 

새로운 데이터의 등장!

초록색 삼각형의 데이터가 추가되었다고 가정해봅시다.

우리 인간은 직관적으로 파란색 원의 데이터와 가깝기 때문에 이 삼각형 데이터는 도미라고 판단하게 됩니다.

비슷한 원리(이번 장 초반부분에서 knn 원리를 설명했다.)로 knn알고리즘 역시 삼각형 데이터를 도미라고 판단하게 될 것입니다.

이를 구현한 코드는 다음과 같습니다.

kn.predict([[30, 600]])

fit()메서드와 마찬가지로 predict()메서드 역시 이차원 리스트를 전달해야 합니다.

코드를 실행해보면  array([1])이 나옵니다. 앞서 우리는 target을 정할 때 '1'을 도미로 가정했습니다.

따라서, 길이 30 무게 600에 대한 새로운 데이터는 knn모델을 통해 도미라고 판단하게 되는 것입니다.

 

그러면 knn알고리즘에서 필요한 것은 데이터를 모두 가지고 있는 것이 전부 아닌가요?

맞습니다. knn에서 새로운 데이터를 예측하기 위해서는 가장 가까운 직선거리에 어떤 데이터가 있는지 살피기만 하면 되기 때문입니다.

결국, fit()메서드에 전달한 데이터를 모두 저장해놓고 있다가 새로운 데이터가 등장하면 가장 가까운 n개의 데이터를 참고하여 데이터를 구분하는 것입니다.

 

여기서 참고할 가까운 n개의 데이터는 어떻게 정해야 할까요?

이는 개발자가 정하기 나름입니다. 데이터가 어떻게 이루어져 있는지 먼저 확인을 한 후에 n을 정하는 것이 바람직합니다.

KNeighborsClassifier 클래스의 기본 default 값으로 n은 5입니다.

n을 바꾸고 싶으면 다음과 같은 형태로 바꾸면 됩니다.

knn = KNeighborsClassifier(n_neighbors = 'n') #n은 직접 정하기

 

 

생각을 해봅시다. 우리가 다룬 데이터(도미, 빙어 데이터)에서 n을 49로 두면 정확도가 어떻게 될까요?

.

.

.

정답은 약 71.4 (35 / 49) % 입니다. 49개 이웃을 모두 참고하게 된다면 도미가 35개로 다수를 차지하기 때문에 모든 데이터를 빙어라고 판단하게 되기 때문입니다. 그렇게 되면 도미는 도미로 예측하여 문제가 없지만 빙어를 도미로 예측하여 14개의 빙어 데이터를 도미로 잘 못판단하게 됩니다. 더욱이 어떠한 새로운 데이터를 넣더라도 모두 도미로 예측할 것입니다. 

 

n을 정하는 것이 knn에서 정말 중요하다고 할 수 있겠네요.