ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • SAM : Sharpness-Aware Minimization For Efficiently Improving Generalization
    2021_Project/generalization 2021. 10. 12. 15:12

    저번 실험에서 training loss 및 validation loss(=MAE) 가 3 초반으로 낮았음에도 불구하고 다른 domain에서의 external validation dataset에 대해서는 잘 예측하지 못함(MAE=5.xx)을 확인하였다. 이는 나의 모델의 generalization이 부족함을 의미하며, 이를 해결하기 위해 domain adaptation이 필요하다는 생각을 하였다.

     model generalization에 대한 방법을 찾던 도중, 교수님께서 최근 발표된 optimization 기법인 SAM에 대해 이야기 하셨으며 이를 한번 내 모델에 적용시켜보라고 권유하셨다. 따라서 이번 포스트에서는 SAM이 무엇인지에 대해 소개하고 brain age prediction model에 적용한 결과에 대해 이야기하도록 하겠다.

     

    SHARPNESS - AWARE MINIMIZATION FOR EFFICIENTLY IMPROVING GENERALIZATION

    SAM optimization 기법은 2021 ICIR conference paper로 publish되었다. 간단히 이야기하면 Sharpness-Aware Minimization(SAM)은 neighborhood가 똑같이 균일하게 low loss를 갖는 parameter를 추구하기 때문에 SAM은 다양한 데이터셋에 대해 model generalization을 향상시킨다(label noise에 대해 robust함을 제공한다).

     

    1. Introduction

     

    많은 모던 뉴럴 네트워크는 training data를 암기하는 경향이 강하며 따라서 쉽게 overfit할 수 있다. 그러므로  model 학습 시 paramerter들이 training set을 넘어서 generalization할 수 있는 방향으로 선택되어야 한다. loss landscape의 geometry(특히 minima의 flatness)와 모델 일반화의 connection은 이론, 경험적으로 많이 연구되었으며 본 논문에서는 loss landsacpe의 geometry를 leveragegksms model generalization approach에 대해 소개한다. 논문에서 요약한 SAM의 특징은 아래와 같다.

    • SAM은 loss value와 loss sharpness를 동시에 minimize하면서 model generalization을 향상시키는 novel procedure이다. SAM function은 parameter가 그 자신만 low loss value를 갖는것이 아닌 주위 neighborhood하기 low loss value를 갖도록 한다.
    • SAM은 label noise에 대해 robust함을 준다.

     2. Sharpness-Aware Minimization (SAM)

     dataset을 학습시킬때 우리는 2종류의 Loss에 대해 정의를 내리고 시작한다. 먼저 training set loss인 Ls(w)는 전체 데이터 D중 S만큼 sampling된 데이터만을 가지고 관측한 Loss이며 population loss인 Ld(w)는 전체적인 global한 loss를 의미한다. 우리는 한정된 학습데이터만을 갖고 있기 때문에 Ls(w)를 minimize하도록 하지만 실제로는 robust한 모델을 만드는 것이 우리의 goal이기 때문에 Ld(w)가 최소화되는 것이 이상적이다라고 할 수 있다.

     하지만 불행하게도 modern overparameterized models들에 있어서 typical한 optimization 기법들은 test 환경에서 suboptimal한 performance를 내기 쉽다. 특히 modern model들에 있어서 Ls(w)는 보통 non-convex모형을 지니기 때문에 수많은 local, global minima를 가지며, 이들은 각각 매우 다른 generalization performance(Ld(w))를 초래한다.

     이를 해결하기 위해 논문에서는 Ls(w)를 최소화하는것 대신 neighborhoods들이 low loss뿐만 아니라 low curvature를 가지도록 parameter value를 찾았으며 그 결과는 figure 1에서 확인 가능하다.

    Ld(w)는 오른쪽 수식보다 크거나 작다고 표현될 수 있는데 (여기서 h는 regularization term이다) 사실상 저 부분이 바로 우리가 원하는 sharpness loss를 수식으로 표현한 것이다. 우리는 이를 inequality의 right hand side부분을 아래와 같이 다시 적으면서 확인할 수 있다.

    이는 위의 수식에 Ls(w)를 더하고 빼준거라서 결과값은 같음을 알 수 있다. 이제 이 부분을 자세히 살펴보면, 먼저 대괄호가 있는 부분은 Ls의 기울기로 Ls의 sharpness를 측정해주는 역할을 한다. sharpness term(대괄호로 묶여 있는 부분)은 이제 training loss value와 regularizer가 합해지게 된다. 보통 저 h부분을 L2 regularization term으로 사용하며 그렇게 된다면 SAM problem은 결국 아래를 해결하 것으로 수식화 할 수 있다.

    그렇다면 실제 모델에 SAM을 어떻게 적용시킬 수 있을까? 친절하게도 논문의 저자들은 SAM을 쉽게 implementation할 수 있도록 open source code를 github에 공개하였다. ( https://github.com/google-research/sam )

     

    GitHub - google-research/sam

    Contribute to google-research/sam development by creating an account on GitHub.

    github.com

    실제로 pytorch 에서 implementation하는 방법은 아래와 같다.

    !pip install sam
    !pip install easydict
    
    import easydict
    args = easydict.EasyDict({
        "adaptive" : True,
        "batch_size" : 8,
        "depth" : 18,
        "dropout" : 0.1,
        "epochs" : 300,
        "label_smoothing" : 0.0,
        "learning_rate" :  0.005,
       # "momentum" : 0.9,
        "rho" : 2.0,
        "weight_decay" : 0.0001,
        "width_factor" : 8
    })

    먼저 github에서의 코드를 사용하기 위해 param들을 지정해줘야하므로 easydict으로 묶어준다.

    class SAM(torch.optim.Optimizer):
        def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
            assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
    
            defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
            super(SAM, self).__init__(params, defaults)
    
            self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
            self.param_groups = self.base_optimizer.param_groups
    
        @torch.no_grad()
        def first_step(self, zero_grad=False):
            grad_norm = self._grad_norm()
            for group in self.param_groups:
                scale = group["rho"] / (grad_norm + 1e-12)
    
                for p in group["params"]:
                    if p.grad is None: continue
                    self.state[p]["old_p"] = p.data.clone()
                    e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                    p.add_(e_w)  # climb to the local maximum "w + e(w)"
    
            if zero_grad: self.zero_grad()
    
        @torch.no_grad()
        def second_step(self, zero_grad=False):
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None: continue
                    p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"
    
            self.base_optimizer.step()  # do the actual "sharpness-aware" update
    
            if zero_grad: self.zero_grad()
    
        @torch.no_grad()
        def step(self, closure=None):
            assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
            closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass
    
            self.first_step(zero_grad=True)
            closure()
            self.second_step()
    
        def _grad_norm(self):
            shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
            norm = torch.norm(
                        torch.stack([
                            ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                            for group in self.param_groups for p in group["params"]
                            if p.grad is not None
                        ]),
                        p=2
                   )
            return norm
    
        def load_state_dict(self, state_dict):
            super().load_state_dict(state_dict)
            self.base_optimizer.param_groups = self.param_groups

    이후 github에 저장되어있는 SAM class code를 복붙한 다음

    base_optimizer = torch.optim.Adam() 
    optimizer = SAM(model.parameters(), base_optimizer, rho=args.rho, adaptive=args.adaptive, lr=args.learning_rate, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=50, gamma=0.1)

    단순하게 나의 초기 optimizer에 SAM을 적용시킨 것을 최종 optimizer라고 지정하면 된다!

     

    하지만 안타깝게도.. 나의 경우에는 SAM을 적용시켰을 때 오히려 test set에 관해 MAE가 더 증가하였으며 다른 일반화 방법을 시도해봐야겠다.

    댓글

Designed by Tistory.