ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [Pytorch] 다른 조합의 Train, Validation set 나누기
    전반적인 딥러닝 기법 2021. 9. 15. 20:20

     보통 feature 기반의 deep learning algorithm의 경우, 좀 더 global한 solution을 위해 cross-validation을 사용한다. 이는 나의 model이 내 특정 test dataset에 overfit하는 것을 방지해주며 모든 데이터를 학습에 사용하므로 data 부족으로 인한 underfitting 문제도 어느 정도 해결해준다. 통상적으로 brain age prediction에 관련해서 대다수의 논문들이 10-fold cross validation을 사용한다. 하지만 이는 비교적 running time이 짧은 모델에서만 쓰일 뿐, 학습시간이 nn일 걸리는 image 기반의 모델에서는 CV를 하지 않는 추세이다.

     나 역시 현재 RTX A6000으로 3D ResNet을 돌릴 때 300 epoch에 대해 1일 3시간이 걸리기 때문에 10-fold 혹은 8-fold CV를 하기는 어렵다. 그렇다면 내 model이 나의 validation set에 대해서 overfitting되고 있다는 것을 어떻게 확인할 수 있을까?

     

     물론 만약 exclusive한 다른 test dataset이 존재한다면, 단순하게 test set에 대해 모델을 평가해주면 된다. 하지만 이는 매우 이상적인 경우로 대부분은 데이터 부족으로 인해 test dataset까지 만들 여유가 없을 것이다. 

     가장 쉽고 빠르게 확인하는 방법은 CV를 2번만 하는 것이다. 현재 우리가 갖고 있는 데이터는 아래 그림과 같이 train:validation=8:2의 비율을 가지고 있다.

    따라서 2번째 CV에서는 그림과 같이 이전 학습 시 train dataset에 쓰였던 데이터 중 일부를(전체의 20%) 새로운 validation set으로 사용한다. 이를 Pytorch 환경에서 코드로 작업하면 아래와 같다.

    import torch
    
    dataset = MRIdata3d(img_size, train_path)
    num_dataset = len(dataset)
    train_num = int(num_dataset*0.8)
    valid_num = int(num_dataset - train_num)
    train_set, valid_set = torch.utils.data.random_split(dataset, [train_num, valid_num], generator=torch.Generator().manual_seed(42))

    여기서는 실행 횟수에 상관없이 dataset을 같은 방식으로 나누고자 random_seed를 설정해주었다. dataset에 대해 80%, 20%의 비율로 train set과 validation set을 만들어준 다음, DataLoader를 사용하여 1번째 CV dataloader를 만들어준다.

    from torch.utils.data import DataLoader, Dataset
    
    train_loader = DataLoader(dataset=train_set, batch_size = 32, shuffle=True, drop_last=True)
    valid_loader = DataLoader(dataset=valid_set, batch_size = 32, shuffle=False, drop_last=True)

    2번째 CV를 할 시에는 another valid_num을 train_set에서 뽑아야 할 것이다. 따라서 2nd CV에서 validation을 part_valid라 하자면 이는 초기 train set의 1/4만큼 랜덤하게 추출한 값이 될 것이며, 나머지인 part_train과 valid_set을 합한 것이 우리의 새로운 train set이 될 것이다. part_train + valid_set 과정에서 dataloader를 합해주는 함수는 pytorch가 제공해주는 ConcatDataset을 사용하면 된다.

    part_valid_num = int(train_num/4)
    part_train_num = int(train_num - part_valid_num)
    
    part_valid, part_train = torch.utils.data.random_split(train_set, [part_valid_num, part_train_num], generator = torch.Generator().manual_seed(42))
    real_train = ConcatDataset([part_train, valid_set])
    
    train_loader = DataLoader(dataset = real_train, batch_size=32, shuffle=True, drop_last=True)
    valid_loader = DataLoader(dataset = part_valid, batch_size=32, shuffle=False, drop_last=True)

    댓글

Designed by Tistory.