귀퉁이 서재

머신러닝 - 14. 에이다 부스트(AdaBoost) 본문

머신러닝

머신러닝 - 14. 에이다 부스트(AdaBoost)

Baek Kyun Shin 2019. 10. 4. 13:12

본 챕터에서는 부스팅 기법 중 가장 기본이 되는 AdaBoost에 대해 알아보겠습니다. 부스팅에 대해서 잘 모르신다면 '머신러닝 - 11. 앙상블 학습 (Ensemble Learning): 배깅(Bagging)과 부스팅(Boosting)'을 참고해주시기 바랍니다.

아래 AdaBoost는 StatQuest의 AdaBoost, Clearly Explained를 요약 정리한 글입니다.

AdaBoost 개요

아래와 같이 노드 하나에 두개의 리프(leaf)를 지닌 트리를 stump라고 합니다.

출처: StatQuest

AdaBoost는 아래와 같이 여러 개의 stump로 구성이 되어있습니다. 이를 Forest of stumps라고 합니다.

출처: StatQuest

트리와 다르게 stump는 정확한 분류를 하지 못합니다. 여러 질문을 통해 데이터를 분류하는 트리와 다르게, stump는 단 하나의 질문으로 데이터를 분류해야하기 때문입니다. 따라서 stump는 약한 학습기(weak learner)입니다. 

이전 챕터에서 살펴봤던 랜덤 포레스트는 여러 개의 트리의 결과를 합산해서 최종 결과를 냅니다. 다수결의 원칙을 통해 말이죠. 최종 분류를 하는데 있어 각각의 트리는 동등한 가중치를 가지고 있습니다. 

출처: StatQeust

하지만 AdaBoost에서는 특정 stump가 다른 stump보다 더 중요합니다. 즉, 더 가중치가 높습니다. 아래의 그림에서 보는 것처럼 크기가 큰 것은 가중치가 더 높은 stump를 뜻합니다. 여기서 가중치가 높다는 것을 Amount of Say가 높다고 표현합니다. 결과에 미치는 영향이 크다는 뜻입니다.

출처: StatQuest

또한, 첫번째 Stump에서 발생한 error는 두번째 Stump의 결과에 영향을 줍니다. 두번째 Stump에서 발생한 error 역시 세번째 Stump의 결과에 영향을 줍니다. 그렇게 마지막 Stump까지 줄줄이 영향을 줍니다.

정리하자면, AdaBoost는 다음과 같은 3가지 특징을 가지고 있습니다.

1. 약한 학습기 (Weak Learner)로 구성되어 있으며, 약한 학습기는 Stump의 형태입니다.
2. 어떤 Stump는 다른 Stump보다 가중치가 높습니다. (Amount of Say가 큽니다.)
3. 각 Stump의 error는 다음 Stump의 결과에 영향을 줍니다.

AdaBoost 작동 원리

이제, AdaBoost의 작동 원리에 대해서 알아보겠습니다.

출처: StatQuest

Chest Pain, Blocked Arteries, Patient Weight에 따른 Heart Disease 여부에 대한 데이터입니다. 맨 처음 Sample Weight는 8개의 데이터 모두 동일하게 1/(total number of samples) = 1/8입니다. 모든 sample의 가중치가 1/8로 동일합니다. 이제 각각의 feature가 target value(여기서는 Heart Disease)에 미치는 영향에 대해 살펴보겠습니다. 

다음은 Chest Pain과 Heart Disease와의 관계입니다.

출처: StatQuest

단순하게, Chest Pain이 Yes이면 Heart Disease도 Yes라고 판단하는 모델입니다. 총 8개의 데이터 중 Chest Pain이 Yes인 데이터(즉, Heart Disease를 Yes라고 판단한 데이터)는 5개, No인 데이터는 3개입니다. Heart Disease를 Yes라고 판단한 것 중 올바르게 판단한 것은 3개, 틀리게 판단한 것은 2개입니다. 반대로, Heart Disease를 No라고 판단한 것 중 올바르게 판단한 것은 2개, 틀리게 판단한 것은 1개입니다. 따라서 위와 같이 구분이 되었습니다. 참고로 여기서는 모든 Sample Weight가 1/8로 같아서, Sample Weight를 무시해도 됩니다. Sample Weight에 관해서는 다음 단계에서 활용합니다. 

이젠, Blocked Arteries와 Heart Disease와의 관계입니다.

출처: StatQuest

마지막으로, Patient Weight와 Heart Disease와의 관계입니다.

출처: StatQuest

이제 각 Stump의 지니 계수를 구합니다.

참고로, 이때 Weight>176을 분기 기준으로 했습니다. 1) 몸무게를 오름차순으로 정렬한 뒤, 2) 인접한 몸무게의 평균을 구한 다음, 3) 각 평균 값마다 지니 불순도를 계산합니다. 4) 가장 작은 지니 불순도를 갖는 인접 몸무게 평균값을 분기 기준으로 잡은 겁니다. 그 값이 176입니다. 이 설명이 헷갈린다면 StatQuest: Decision Trees영상을 14:04초부터 보시기 바랍니다.

출처: StatQuest

마지막 Stump의 지니 계수가 가장 작기 때문에, 마지막 Stump를 forest의 첫 Stump로 지정합니다. 이 Stump가 최종 결과 예측에 얼마만큼 중요한지 알아보겠습니다. 

Amount of Say 구하기

바로 위 그림에서 틀리게 분류한 것이 No Heart Disease의 Incorrect로 1개밖에 없습니다. 따라서 Total Error = 1/8입니다.

출처: StatQuest

모든 Sample Weights의 합은 1이기 때문에, Total Error는 0과 1 사이의 값을 갖습니다. Total Error가 Amount of Say를 결정합니다. Amount of Say는 최종 분류에 있어서 해당 Stump가 얼마만큼의 영향을 주는가를 뜻합니다. Amount of Say를 구하는 공식은 아래와 같습니다.

출처: StatQuest

Amount of Say를 그래프로 그려보면 아래와 같습니다. X축은 Total Error, Y축은 Amount of Say입니다. Total Error가 0이면 Amount of Say는 굉장히 큰 양수이고, Total Error가 1이면 Amount of Say는 굉장히 작은 음수가 됩니다. 따라서 Total Error가 0이면 항상 올바른 분류를 한다는 뜻이고, 1이면 항상 반대로 분류를 한다는 뜻입니다. Total Error가 0.5일 때는 Amount of Say가 0입니다. 동전을 던지는 것과 마찬가지로 의미가 없다는 뜻입니다.

출처: StatQuest

다시 Stump 문제로 돌아가보겠습니다. Total Error = 1/8이라고 했으므로,

= 0.97입니다. 그래프 상에서 표현해보면, Total Error가 1/8이고, Amount of Say = 0.97인 아래 지점입니다.  

출처: StatQuest

 

Chest Pain Stump의 Amount of Say를 구하는 방법입니다. 사실 이 단계는 할 필요가 없지만 Amount of Say를 구하는 방법을 다시 한번 설명하려고 추가했습니다.

출처: StatQuest

이 Stump의 Total Error = 1/8 + 1/8 + 1/8 = 3/8입니다. Amount of Say를 구해보면,

= 0.42입니다.

Blocked Arteries도 동일한 방식으로 Amount of Say를 구해줄 수 있습니다.

샘플 가중치 설정

Adaboost에서는 하나의 Stump가 잘못 분류한 sample에 대해서는 다음 Stump로 넘겨줄 때 가중치를 더 높여서 넘겨줍니다. 그래야 다음 Stump에서 해당 Sample에 더 집중해서 올바로 분류해주기 때문입니다. 맨 처음 Weight Stump에서는 아래 빨간 네모 안에 있는 sample만 잘못 분류했습니다. 따라서 해당 sample의 weight를 1/8보다 크게 하고, 나머지 sample의 weight는 1/8보다 작게 해서 다음 Stump로 넘겨줍니다. 다음 Stump로 넘겨줄 때의 새로운 sample weight를 구하는 공식은 아래와 같습니다. (이전 Stump에서 잘못 분류된 sample인 경우)

출처: StatQuest

(이전 stump에서 잘못 분류된 sample의) New Sample Weight  = (1/8) * e^(0.97) = (1/8) * 2.64 = 0.33 입니다. 기존의 sample weight = 1/8 = 0.125였는데 이보다 더 높아졌습니다.   

이전 Stump에서 잘 분류된 sample인 경우의 새로운 sample weight 구하는 공식은 아래와 같습니다.

출처: StatQuest

amount of say에 - 부호만 붙이면 됩니다. 

(이전 stump에서 잘 분류된 sample의) New Sample Weight = (1/8) * e^(-0.97) = (1/8) * 0.38 = 0.05 입니다. 기존의 weight인 0.125보다 더 작아졌습니다.

다시 설명하자면, 이전 Stump에서 잘못 분류한 sample에는 sample weight를 높여주고, 이전 Stump에서 제대로 분류한 sample에는 sample weight를 줄여줍니다. 그래야 다음 Stump에서 이전 Stump에서 잘못 분류한 것에 더 집중해서 올바로 분류해주기 때문입니다. 새로 구한 sample weight는 아래와 같습니다.

출처: StatQuest

New Sample Weight는 공식에 의해 구한 weight입니다. 단, New Sample Weight를 다 더했을 때 값은 0.68로 1이 되지 않습니다. Sample Weight의 합은 항상 1이 되어야 하므로 오른쪽 Weight처럼 정규화해줍니다. 정규화하는 방법은 각 Weight를 모든 Weight의 합으로 나누어주는 것입니다. 

0.05 / 0.68 = 0.07 입니다. 그리하여 정규화한 Weight를 맨 오른쪽에 적었습니다. 이제 기존 Weight는 모두 지우고 정규화한 새로운 Weight만 보겠습니다.

 

출처: StatQuest

그다음 단계는 샘플링으로 새로운 테이블을 만드는 것입니다. 0부터 1까지 숫자를 무작위로 뽑습니다. 이때 0~0.07 사이의 숫자가 나오면 첫번째 sample을 선택합니다. 0.07~0.14 사이의 숫자가 나오면 두번째 sample을 선택합니다. 0.14~0.21 사이의 숫자가 나오면 세번째 sample을 선택합니다. 0.21~0.70 사이의 숫자가 나오면 네번째 sample을 선택합니다. 눈치 채셨겠지만 sample weight의 누적값에 해당하는 sample을 뽑는 것입니다. 그렇게 원래 테이블의 sample수와 똑같은 sample 수를 가진 새로운 테이블을 구성합니다. 뽑힌 sample들을 보니 중복되는 것도 있습니다. 원래 테이블에서 sample weight가 0.49인 sample이 4번이나 뽑혔습니다. 당연히 0.21 ~ 0.70 사이의 숫자가 나올 확률이 크니, 해당 sample도 많은 겁니다.

출처: StatQuest

이제 원래 테이블은 지우고, 샘플링한 새로운 테이블을 가져옵니다. 모든 sample의 weight는 다시 1/8로 통일합니다. 첫 Stump에서 잘못 분류했던 sample이 4번이나 포함된다고 했습니다. sample weight는 1/8로 동일하더라도 똑같은 데이터가 4개가 있어서 실제로는 4/8의 weight를 갖는 것입니다. 이는 처음에 잘못 분류를 했기 때문에, 그다음에는 weight를 높여서 제대로 분류하기 위함입니다. weight가 높아지니 해당 sample에 가중치를 더 두고 분류를 할 것입니다.

출처: StatQuest

다시 맨 처음했던 것과 같은 방식으로 진행하면 됩니다.

최종 분류

이렇게 여러 차례 진행하면 아래와 같이 각 Stump마다의 Amount of Say가 나옵니다. 왼쪽은 Heart Disease가 있다고 판단한 Stump이고, 오른쪽은 Heart Disease가 없다고 판단한 Stump라고 합시다. 각 Stump의 Amount of Say를 더하면 Total Amount of Say가 나옵니다. 이 경우 Heart Disease가 있다는 것의 Total Amount of Say가 2.7로 더 큽니다(0.97+0.32+0.78+0.63 = 2.7). 따라서  최종적으로 Heart Disease가 있다고 분류를 할 수 있습니다.

출처: StatQuest

각 Stump마다 분류를 해주지만 Stump 하나의 분류력은 굉장히 약합니다. 앞서 말씀드린 것처럼 Stump는 Weak Learner입니다. 하지만 여러 Stump의 결과를 종합하면 강한 학습기(Strong Learner)가 됩니다. 또한, 각 Stump 마다 Amonut of Say가 다릅니다. 각 Stump의 Amount of Say를 합치면 Total Amount of Say가 나오고 이를 통해 최종 분류를 해주는 것입니다.

Reference

StatQuest: AdaBoost, Clearly Explained

14 Comments
  • 프로필사진 123 2020.03.26 12:30 'weight가 높아지니 당연히 더 집중해서 분류를 하겠지요.'가 구체적이지 못한 것 같습니다.
  • 프로필사진 Baek Kyun Shin 2020.03.27 21:04 신고 안녕하세요! 애매한 문장을 잘 지적해주셔서 감사합니다.
    해당 문장은 다음과 같이 바꾸었습니다.
    "처음에 잘못 분류를 했기 때문에, 그 다음에는 weight를 높여서 제대로 분류하기 위함입니다. weight가 높아지니 해당 sample에 가중치를 더 두고 분류를 할 것입니다."
  • 프로필사진 11 2021.04.07 11:08 안녕하세요. 올려주신 글 감사히 잘 보고 있습니다.
    글을 읽던 중 샘플 가중치 설정 단락 밑에 첫 번째 줄 'Adaboost에서는 하나의 Stump가 잘못 분류한 sample에 대해서는 다음 Sutmp로 넘겨줄 때 가중치를 더 높여서 넘겨줍니다.' 에서 Sutmp 오타가 발견되어 댓글 남깁니다.

    그리고 궁금한 점 한 가지 있습니다.
    "이전 Stump에서 잘못 분류된 sample의 경우 sample weight를 증가 시켜주고 이전 Stump에서 잘 분류된 sample의 경우 sample weight를 감소 시켜줍니다. 그래야 다음 Stump에서 이전 Stump에서 잘못 분류한 것에 더 집중을 해서 올바른 분류를 해주기 때문입니다."
    위 글 중 다음 Stump에서 이전 Stump에서 잘못 분류한 것에 더 집중을 해서 올바른 분류를 해준다고 하셨는데 어떻게 올바른 분류로 바뀌는지 과정이 궁금합니다. 만약 분류 규칙이 바뀐다면 어느 기준으로 분류 규칙이 바뀌는지 잘 이해가 되지 않아 질문 드립니다.
  • 프로필사진 Baek Kyun Shin 2021.04.07 22:13 신고 감사합니다 오타 수정했습니다!

    음.. 어떤 게 이해가 안 가시는지 잘 모르겠네요 ㅜㅜ 본문에 잘못 분류된 샘플과 잘 분류된 샘플의 가중치를 구하는 공식, 그리고 그 가중치를 활용해서 다시 분류하는 프로세스를 설명해놨습니다. 분류 방법은 계속 동일합니다. 가중치만 바꿔 제대로 분류를 못한 샘플을 더 잘 분류하도록 모델을 튜닝하는 것입니다. '샘플 가중치 설정' 제목으로 돼 있는 부분 전체에 프로세스를 설명해 놓았습니다.

    제가 질문을 이해 못하는 것일 수도 있으니 혹시 추가로 궁금하신 게 있으시면 말씀해주세요 ^^
  • 프로필사진 무민 2021.04.08 14:38 lgbm 논문으로 공부하다가 SGB와는 다르게 adaboost 기반은 down sampling 방식을 적용하기 힘들다는 말에 의문을 가지고
    공부를 하다가 여기까지 흘러들어왔네요
    이 글을 읽고 weight가 부여되고 random sampling 방식으로 샘플들이 갱신이 되는 것을 보고 의문점이 해소가 되었습니다.
    좋은 설명 글 감사드립니다.
  • 프로필사진 Baek Kyun Shin 2021.04.08 21:14 신고 읽어주셔서 감사드립니다 :-)
  • 프로필사진 연호 2021.04.12 00:38 머신러닝에 대해 구글링하다가 우연히 보고 정말 재밌고 유익하게 글 감사히 보고 있습니다!!
    혹시 마지막 최종 분류에서 '최종적으로 Heart Disease가 있다고 분류를 할 수 있습니다.' 라는 문장은
    어떤 데이터를 기준으로 나온 결론일까요?? 비교 대상이 따로 없이 그냥 심장 질환이 있다고 분류되는 결론에서
    약간 이해가 되지 않습니다 ㅠㅠ
  • 프로필사진 Baek Kyun Shin 2021.04.15 00:25 신고 재밌고 유익하게 봐주셔서 감사드립니다.

    심장병이 있는 경우 Total amount of say가 2.7이고, 심장병이 없는 경우 Total amount of say가 1.2입니다. Total amount of say가 더 큰 쪽이 최종 분류값이 됩니다. 그래서 결론이 심장병이 있다고 분류한 것입니다.
  • 프로필사진 보근 2021.11.02 16:40 스텀프는 몇개까지 생성을 하나요? 사용자 지정값일까요?
  • 프로필사진 Baek Kyun Shin 2021.11.04 13:52 신고 네 파라미터로 지정합니다.
  • 프로필사진 lee2 2022.03.15 00:08 정말 유익한 글입니다. 많이 배우고있습니다.
    질문이있습니다.
    1.patient weigt구분시 176은 어떤 기준으로 나눈건지요?

    2.가중치를 부여한 데이터셋으로 다시 셋팅후 동일한 weight를 부여하고 다시 gini계수가 가장 낮았던 피처를 찾아서 스텀프를 구성하는 건가요?

    3.최종분류에서 질병을 구분한 amount of err가 이해가 안됩니다. (질병을 둘로 나뉜 amount of err 어떻게 구해지는 건가요?)
  • 프로필사진 Baek Kyun Shin 2022.03.15 20:26 신고 질문해주셔서 고맙습니다 ^^

    Weight>176을 분기를 정하는 기준은 이렇습니다. 1) 몸무게를 오름차순으로 정렬한 뒤, 2) 인접한 몸무게의 평균을 구한 다음, 3) 각 평균 값마다 지니 불순도를 계산합니다. 4) 가장 작은 지니 불순도를 갖는 인접 몸무게 평균값을 분기 기준으로 잡은 겁니다. 그 값이 176입니다. 이 설명이 헷갈린다면 https://www.youtube.com/watch?v=7VeUPuFGJHk 영상을 14:04초부터 보시기 바랍니다. 1분 가량 설명하는데 바로 이해되실 겁니다.

    최종분류에서, 각 stump마다 Amount of Say를 구해서 Amount of Say의 합이 큰 값을 최종분류값으로 정합니다. Amount of Say를 구하는 방법은 본문에서 설명을 했습니다.

    더 궁금한 점이 있다면 댓글 남겨주세요 :)
  • 프로필사진 정선호 2022.03.16 13:30 감사합니다.
    마지막에
    심장병이 있는 경우 Total amount of say가 2.7이고, 심장병이 없는 경우 Total amount of say가 1.2입니다. Total amount of say가 더 큰 쪽이 최종 분류값이 됩니다. 그래서 결론이 심장병이 있다고 분류한 것입니다.
    => 이 말이 이해가 안가는데, 테스트 데이터를 넣었을 때의 결론인가요? 아니면 분류모델을 2.7인 기준으로 설정한다는 건가요?
    total amoun = 2.7을 구해놓고 뜬금없이 왜 심장병이 없다는 게 왜 나왔는지 궁금합니다.
  • 프로필사진 Baek Kyun Shin 2022.03.17 22:41 신고 아 제가 설명을 좀 헷갈리게 했네요.
    심장병이 있다고 판단한 stump가 여러 개 있고, 심장병이 없다고 판단한 stump도 여러 개 있을 겁니다. 심장병이 있다고 판단한 stump의 amount of say를 구하면 2.7이 되고, 심장병이 없다고 판단한 stump의 amount of say를 구하면 1.2가 나온다고 합시다. 이때 에이다 부스트 모델은 모든 stump의 결과를 종합해 최종적으로 심장병이 있다고 판단을 하는 겁니다. 심장병이 있다고 판단한 stump의 amount of say가 2.7로 더 높기 때문입니다.
댓글쓰기 폼