ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • Synthesizing Tabular Data using Generative Adversarial Network : TGAN
    2021_Project/Linear Regression 2021. 8. 23. 19:32

    medical dataset을 사용하여 딥러닝을 학습할 때 가장 어려운 점은, 데이터셋의 확보가 어렵다는 것이다. 좋은 성능을 위해 다량의 학습 데이터가 필요하다는 면에서 medical domain에서의 인공지능은 challenging하다고 볼 수 있다. 이러한 데이터 부족 문제를 해결하기 위해 최근들어선 GAN을 사용하는 사람들이 늘어나고 있다. 특히 medical image dataset에 대해선 GAN으로 학습하여 만든 synthetic image와 source image를 합하여 data augmentation을 했다는 연구 결과들이 들려오고 있다. 하지만, 만약 내가 갖고 있는 데이터가 영상데이터가 아닌 다른 feature 데이터면 어떻게 할 것인가? 만약 내가 갖고 있는 데이터가 EMR 혹은 EHR dataset일 경우 data augmentation을 해주고 싶다면 어떻게 해야 할까?

    한가지 방법으로는 각 피쳐 데이터의 col값들의 pdf를 추정한 후(k-s test를 이용해서) 그 pdf로부터 random한 값을 추출해 올 수도 있다. 물론 그렇게 될 경우 각각의 추출된 random 변수들이 원본 데이터의 상관관계를 따라가도록 해주는 장치가 필요할 것이다. 하지만 데이터들의 pdf를 추정하는 작업이 매우 까다롭다는 것을 알게 되었으며 (물론 데이터가 정말 많다면 웬만해서는 가우시안을 따르겠지만 우리 데이터는 매우 적다는 가정이 있기 때문이다) 좀더 automatic한 방법으로 data augmentation을 수행해주는 방법을 알고 싶어 좀더 찾아보았고 그러던 도중 tabular dataset을 위한 GAN 연구들이 활발하게 이루어지고 있다는 것을 발견했다. 따라서 오늘은 tabular gan에 대해 다루는 논문에 대해 간략하게 소개하며 실제 dataset에 대해 어떻게 implementation하는지를 설명하겠다.

     

    Synthesizing Tabular Data using Generative Adversarial Network

     

    2018년 발표된 위 논문은 tabular data를 위한 GAN에 대해 소개하고 있다. 논문 저자들에 따르면 3가지 데이터셋에 평가해봤을 때 TGAN이 column들간의 correlation을 잘 capture해주며 데이터 증폭에 보다 효과적이라고 한다. 여기서 TGAN은 혼합된 변수 타입(다항, discrete와 continuous 한)을 지는 표 data를 생성해주는 모델이다.

     

    3. GANs for tabular data

    표 데이터를 위한 GAN을 구성하는 것은 표 데이터의 다양한 타임 (numerical, categorical, time, text, 등등) 때문에 어려운 task이다. 따라서 데이터의 형태가 다양하기 때문에 논문 저자들은 synthetic table generation task를 간소화했으며 본 논문에서는 데이터를 크게 연속형 변수와 불연속성 변수로 나누었다. (sequential 한 data는 고려하지 않는다.)  Generative Model M의 목표는 아래와 같이 크게 두 가지로 볼 수 있다.

    1. Tnew를 사용해 학습된 ML model과 T를 사용해 학습된 ML model이 real test table인 Ttest에 대해서 비슷한 accuracy를 보인다.
    2. Mutual information :  T와 Tnew 사이에 있는 임의의 두 변수 i,j 의 상호 정보량은 유사하다.

     

    3.1 Reversibla Data Transformation

    논문에서는 numerical한 변수들을 (-1, 1)의 범위의 scalar 값들로 convert해주었으며 discrete한 variable을 multinomial(수치형) distribution으로 convert해주었다.

     

    ⊙ Mode-specific normalization for numerical variables

    표 데이터에서의 numerical한 변수들은 종종 multimodal distribution을 지닌다. 따라서 논문에서는 continous 변수에서의 mode 개수를 측정하기 위해 gaussian kernel density estimation을 해주었다. 여기서 mode란 pdf에서의 봉우리 개수를 의미한다. 또한 multimodal distribution에서 value들을 효과적으로 샘플링해주기 위해 Gaussian Mixture model(GMM)을 사용하여 numerical variable의 value들을 cluster해주었다. GMM을 이용하여 sample하는 방법은 아래와 같다.

     

     

    먼저, numerical한 variable Ci가 m개의 mode를 갖는다고 가정하자. 우리는 m Gaussian distribution을 갖는 확률 분포를 modeling할 것인데, 여기서 u는 C를 m gaussian distribution으로 정규화 한 것을, v는 C를 정규화 한 후 [-0.99, 0.99]로 clipping해준 것을 의미한다.

     

    ⊙ Smoothing for categorical variables

    categorical variable을 생성하는 문제는 자연어처리에서 language generation(nlg)와 비슷한 challenge를 겪게 된다. nlg에 있어서는 사람들은 모델을 미분가능하도록 만들기 위해 강화학습이나 Gumbel sofrmax를 사용한다. 본 논문에서 다룬 categorical 변수들은 일반 nlp 문제들보단 vocabular의 사이즈가 작기 때문에 softmax를 이용하여 확률분포를 만들었다. 하지만 그전에 categorical 변수들을 onehot으로 인코딩 해준 후 noise를 추가하여 normalization 시켜준다.

    이렇게 전처리하게 되었을 때 (continous) + (category : one-hot)으로 처리 해 줄 수 있으며 이렇게 만든 u,v,d vector는 generator의 output이자 discriminator의 input이 된다.

    3.2 Model and data generation

    GAN 모델에서 discriminator D에는 MLP를 generator G에는 LSTM을 사용한다.

    ⊙ Generator

    Generator에서는 LSTM을 사용한다 했는데, 이 LSTM이 무엇인지 이야기 하기 위해서는 먼저 RNN이 무엇인가부터 설명이 필요하다.

    내가 하나의 글 지문을 읽는 상황이라고 가정해보자. 나는 글을 이해할 때 현재 보고 있는 단어 뿐만 아니라 그 전 단어들의 조합도 같이 고려할 것이다. 예를 들어 '나는 책을 보는 중이다.'라는 문장을 읽을 때도 현재 보고 있는 단어는 '보는'이지만 이전의 '나는'과 '책을'이라는 정보를 통해 보고 있는 대상이 책임을 캐치할 수 있는 것이다. 이처럼 RNN은 지속적으로 이전 단계의 정보량을 받는 구조를 의미한다. 하지만 RNN에는 큰 문제가 있다. input인 word embedding의 크기가 어떻게 될지 모르기 때문에 weight 값을 담고 있는 matrix W의 크기를 정할 수 없다. CNN의 경우 input의 크기를 256*256처럼 고정을 할 수 있지만, RNN은 그것이 불가능하다. 따라서 매 layer마다 다른 weight'값을 주는 것이 아닌 모두 동일한 weight값을 주게 된다. 이럴 경우 gradient vanishing problem이 생기게 된다. 또한 관련 정보와 그 정보를 사용하는 지점 사이 거리가 멀 경우 역전파 시 gradient가 점차 줄어 학습능력이 크게 저하된다. RNN의 이러한 문제를 해결하기 위해 LSTM이 등장하게 된다. 

    LSTM은 아래 그림과 같이 RNN의 hidden state(초록색 박스) 부분에 cell-state를 추가한 구조를 말한다. Cell state는 컨베이어 벨트와 같아서 작은 linear interaction만을 적용시키며 전체 체인을 계속 구동한다. 따라서 정보가 전혀 바뀌지 않고 그대로 흐르게 하는것을 매우 쉽게 할 수 있다. LSTM cell state 뭔가를 더하거나 없앨 있는 능력이 있는데, 능력은 gate라고 불리는 구조에 의해서 제어된다. LSTM에 대한 더 자세한 이야기는 후에 다른 포스트에서 다루도록 하겠다.

    LSTM을 어느 정도 이해했으면 이제 자연어처리에서 또 중요하게 다뤄지고 있는 Attention mechanism에 대해 간략하게 설명하겠다. Attention mechanism의 핵심을 모델로 하여금 중요한 부분만 보고 예측하게 하자!이다. Attention에서의 인코더와 디코더는 아래 설명과 같다.

    • 인코더 : 입력 시컨스를 처리하는 부분. 정보를 고정된 길이의 context vector로 압축한다. 전체 시컨스의 의미를 잘 요약해주는 역할을 한다.
    • 디코더 : 인코더가 압축한 context vector를 초기화 해준 후 target 시컨스로 변환해주는 역할을 한다. 디코더의 처음 상태는 인코더 네트워크의 마지막 state를 사용한다.

    우선 인코더는 길이 T 입력 X 받아 히든 스테이트 벡터 h 생성한다. a 출력이 어떤 입력을 많이 집중해서 보면 되는지에 대한 가중치 벡터로 모든 a값들을 더하면 1 된다. A는 attention vector를 의미한다. 디코더의 현재 시간 i의 hidden state는 이전 시간의 hidden state, 이전 시간의 디코더 출력, 그리고 현재 context vector를 입력으로 받아서 구해진다. 여기서 context vecotr는 입력 단어의 길이동안 어텐션 벡터의 가중합으로 표현되는데 이는 처음부터 끝까지 모든 단어를 보고 있음을 알 수 있는 식이다.

    다시 본 논문으로 돌아와서 Output hidden state size = nh으로 표현하며  step t마다의 LSTM input = 확률변수 z한다. Attention based context vector at 모든 이전단계의 LSTM output h1:t 대해 가중되므로 nh dimension 갖게 된다.(attention vector 학습시킨 , context vector 구함) 또한 Hidden vector s y,s,c 조합으로 나타내며, 추후 s로부터 output variable 구한다.

     

    ⊙ Discriminator

    discriminator의 경우 ㅣ- layer fully connected neural net discriminator 사용한다. 또한 Input으로는 V, U, D concat하여 사용한다. 여기서 Diversity는 mini-batch discrimination vector를, BN는 batch normalization을 의미한다.

     

    4. Evaluation setup

    평가에서는 얼마나 TGAN 표에 있는 variable 상관관계를 capture했는지, data scientist들이 실제로 synthetic data 바로 learn model 사용할 있는지에 focus한다.

    결과를 보면 real data와 synthetic data간의 평균적인 performance gap은 5.7%정도이며 아래 그림과 같이 인조 데이터가 원본 데이터의 상관관계를 어느정도 잘 따라감을 확인할 수 있다.

    여기까지가 TGAN이 어떻게 작동되고 있는지에 대한 간략한 설명이다. 그렇다면 실제로 PyTorch 환경에서 내 데이터에 대해 TGAN을 어떻게 적용시키는 것일까? 정말 편리하게도 현재 개발자들이 TGAN을 위한 라이브러리를 제공하고 있다. tabular data를 증폭시키기 위해선 우리는 간단하게 tgan을 install해주면 되기만 하다. TGAN의 전반적인 명령어들에 대한 설명은 아래 깃헙으로 들어가면 볼 수 있다.

    https://github.com/sdv-dev/TGAN

     

    GitHub - sdv-dev/TGAN: Generative adversarial training for generating synthetic tabular data.

    Generative adversarial training for generating synthetic tabular data. - GitHub - sdv-dev/TGAN: Generative adversarial training for generating synthetic tabular data.

    github.com

    여러 크기의 데이터에 대해 TGAN을 적용시킨 결과 흥미로운 사실을 발견했다. column개수가 적을수록, TGAN을 적용시켰을 때 source data의 correlation을 잘 따라가지만 반대로 데이터 사이즈가 크면 data의 context를 잘 모방하지 못한다. 예를 들어 column 개수가 10개인 데이터는 완벽하게 새로운 synthetic data를 생성해 냈지만, column 개수가 530개일 경우, synthetic data의 상관관계가 엉망이였다. 

    내가 현재 다루고 있는 데이터 사이즈는 꽤 크기 때문에, 먼저 RFE 방법을 통해 데이터의 feature 개수를 51로 줄이고 다시 TGAN을 적용시켜 보았다. RandomForestRegressor을 사용해 반복적으로 feature 개수를 줄여주었으며 아래와 같이 최종적으로 51 column이 되었다.

    내 데이터의 경우 discrete한 값을 지니는 column은 'age'밖에 없기 때문에 아래와 같이 설정한 후 3000 epoch동안 모델을 돌려준 모습이다.

    TGAN을 돌려준 결과 데이터가 아래와 같이 형성되었으며 겉으로 보기엔 어느정도 원본 데이터를 잘 따라한 것처럼 보인다.

    하지만, rfe_data와 samples의 correlation heatmap을 살펴보면 원본 데이터의 context를 전혀 학습하지 못했음을 알 수 있다.

    하지만 dacon에서 주최한 2021_시스템 품질 변화로 인한 사용자 불편 예지 공모전에서 제공한 데이터로 tgan을 똑같이 적용시켰을 때는 아래 그림처럼 correlation을 잘 따라줌을 발견하였다. 왼쪽이 원본 데이터의 heatmap이고 오른쪽이 원본데이터를 이용하여 만든 synthetic 데이터에 관한 heatmap이다.

    feature 개수가 더 많은 데이터셋에 대해 어떻게 gan을 적용시켜야 할지는 더 생각을 해봐야할 듯 하다.

     

     

    https://arxiv.org/abs/1811.11264

     

    Synthesizing Tabular Data using Generative Adversarial Networks

    Generative adversarial networks (GANs) implicitly learn the probability distribution of a dataset and can draw samples from the distribution. This paper presents, Tabular GAN (TGAN), a generative adversarial network which can generate tabular data like med

    arxiv.org

    '2021_Project > Linear Regression' 카테고리의 다른 글

    Decision Tree Regressor  (0) 2021.08.05

    댓글

Designed by Tistory.