ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • Cost function for Segmentation
    segmentation 2022. 2. 4. 18:45

    이번 포스트에서는 segmentation에서 training 시 사용할 수 있는 cost function에 대해 소개하고자 한다.

     

    Segmentation은 결국 모든 pixel에 대해서 per-pixel-classification 문제로 해석 가능하기 때문에 가장 간단한 cost function으로는 Cross-entropy를 생각 할 수 있다. Cross-entropy는 우선 각 pixel에서 시행되며, 이렇게 얻은 독립적인 N개의 cross entropy값은 추후 averaged된다. 하지만 주변의 error값을 고려하지 않기 때문에 별로 좋지 않으며, 특정 class가 training data에서 많이 차지하고 있을때 성능이 현저히 저하된다. Segmentation에서 사용하는 cross entropy 값은 아래 수식과 같다.

    Class간의 unbalance를 조절하기 위해 등장한 cost function이 이제 Weighted Cross Entropy(WCE)이다.

     

    1. Weighted Cross Entropy (WCE)

    이는 말그대로 손실함수에 가중치를 부여하는 것이다. 보다 부족한 class에 가중치를 곱해주어 loss값을 키워주므로 모델 학습시 그 부분에 더 집중할 수 있게 해준다. 2 class를 가진 경우, WCE의 수식은 아래와 같다.

     

    수식을 이해하기 쉽게 예시를 들어 설명하도록 하겠다. 예를 들어 좌측 그림과 같은 input label이 있다고 가정해보자. (우리는 초록색 영역을 segmentationg하는 task를 수행 중이다.) Label(pi=1)에 해당되는 부분(초록)이 적을 경우 cross entropy를 이용하면 해당 p값은 최적의 cost 값을 찾을 때 영향을 적게 준다. 따라서 label(vessel) size(#voxel)이 작을 때 더 큰 w값을 부여해줌으로 class간의 unbalance를 조절해준다. (w가 증가하면 p가 감소하니까..)

     

     

     

     

    2. Dice Loss (DL)

    보다 더 많이 사용하는 cost funtion은 Dice coefficient(두개의 binary set사이의 유사도를 측정하기 위해 쓰인다)에서 유래한다. Dice loss의 수식은 아래와 같으며 분모가 0이 되는 것을 방지하기 위해 epsilon값을 넣어준다. 또한 모델은 loss가 작아지는 방향으로 학습이 되기 떄문에 1에서 빼준다.

     

    3. Generalized Dice Loss (GDL)

    Multiple class segmentation의 경우 평가 방법은 아래와 같다.

     

     

    좌측 그림에서와 같이 ∑ri가 작을 경우 (Label(pi=2)에 해당하는 부분이 작을 경우) GDL을 이용하면 해당 label은 최적의 cost를 찾을 떄 영향을 적게 준다.

    voxel 개수가 적은 label에 더 큰 w값을 부여한다. (∑ri 감소 = wi 증가)

    댓글

Designed by Tistory.