ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • Decision Tree Regressor
    2021_Project/Linear Regression 2021. 8. 5. 13:36

    Decision Tree Regressor을 설명하기 앞서 우선 Decision Tree 모델이 무엇인지 설명하도록 하겠다. Decision Tree는 일정한 기준을 질문으로 제시했을 때, '예' 혹은 '아니오'로 갈라질 수 있는 결정 모델을 의미한다.

    조건식에 따라 여러 노드들이 마치 나무처럼 형성이 되기 때문에 Decision Tree라고 불린다. 그렇다면 이 decision tree와 회귀모델이 만나면 어떻게 될까? Decision tree regressor을 먼저 설명하기 앞서 decision tree classification 모델을 살펴보는 것이 이해하기 더 쉬울 것이다.

     

    1. Decision Tree classification

    Decision tree classification에서 root node와 leaf node만으로 이루어져 있는(깊이가 1인) 가장 간단한 분류 모델을 생각해보자. 좌측 그림에서의 가로 축이 X[0] , 세로 축이 X[1]이라고 할 때 우리는 우선적으로 X[1]<=0.0596의 조건식으로 데이터를  파란영역과 빨간 영역으로 나누었다. 여기서 count는 크기 2의 배열로 각 클래스의 데이터 포인트가 몇개 존재하는지를 알려준다. 예를들어 count = [48,18]은 파란영역에 해당되는 동그라미 class가 48개, 세모 class가 18개인 셈이다.

    깊이가 너무 얕은 경우, 이처럼 misleading되는 data 개수가 많을 수 있다. 따라서 우리는 조건식을 더 추가해주어 보다 많은 구분선을 만들어준다.(아래 그림 참고)

    새로운 데이터 포인트에 대한 예측은 주어진 데이터 포인트가 분할된 영역 중 어디에 놓이는지를 확인하면 된다. 따라서 그 영역의 target 값 중 다수 (순수 노드라면 하나)인 것을 예측 결과로 한다. 하지만 tree 구조 자체가 복잡하다 보니 기본적으로 decision tree를 만들게 되면 overfitting될 확률이 매우 크다. 그럴 경우 사전가지치기와 사후 가지치기로 해결한다.

    사전 가지치기 (pre_pruning) : 트리 생성을 일찍 중단
    사후 가지치기 (post_pruning) : 트리 생성 후, 데이터 포인트가 적은 노드를 삭제하거나 병합

    만약 decisiion tree classification을 scikit-learn을 이용하여 구현한다면 tree에서 model.feature_importance를 지원해주기 때문에 쉽게 어떤 feature들이 모델 예측에 있어서 중요한지 알 수 있다. 이 feature importance을 보고 feature selection(reduction)을 시켜주어 최종 모델을 돌리는 경우도 많다.

     

    여기까지가 Decision Tree Classification에 대한 대략적인 설명이였다. 이제부턴 회귀모델에 대해 이야기하도록 하겠다.

     

    2. Decision Tree regressor

    class 를 의미하는 정수 값을 담고 있는 Decision Tree Classifier과는 다르게 Decision Tree Regressor은 leaf 노드에 실수(연속형)을 담고 있다. 결국엔 회귀모델도 분류모델과 마찬가지로 어떠한 기준선을 기준으로 값을 예측하는 것이기 때문에 트리기반의 회귀 모델들은 훈련 데이터 범위 밖의 포인트에 대해 예측을 할 수 없다.

    그렇다면 Decision Tree Regressor에서 node는 어떻게 생성하는 것일까? 회귀나무는 MSE를 낮추는 방향으로 가지를 뻗어 나간다. 2개의 feature(X1, X2)에 대해 regression을 진행한다고 가정할 때, 회귀나무의 예시는 아래 그림과 같을 것이다.

    제일 우측의 그림은 R1, R2, ... Rn으로 표현된 범주들을 시각적으로 표현한 그림이다. 만약 새로운 샘플 x에대해 추정값 f(x)를 구한다면 다음 수식과 같다.

    이는 x가 Rm 이라는 범주에 속해 있으면 Cm이라는 일정한 값을 출력해줌을 의미한다. 여기서 Cm은 Rm에 포함되어 있는 샘플들의 평균값을 의미한다. 그렇다면 Rm은 어떻게 정하는 것일까?

    먼저 특정 조건식 (X[1] <= 0.0596과 같은)을 이용하여 일정 기준으로 나눈다. 이 때 0.0596인 기준점 s를 결정하기 위해 회귀모델에서는 mse를 사용하여 이를 minimize하는 방향으로 찾게 된다.

     

    댓글

Designed by Tistory.