얼레벌레

논문리뷰) Deep Residual Learning for Image Recognition 본문

AI/DL

논문리뷰) Deep Residual Learning for Image Recognition

낭낭이 2023. 8. 3. 00:04

Abstract

깊은 neural network일수록 훈련이 어렵다는 단점을 보완하기 위해 residual learning framework를 제안한다. 이를 이용해 이전보다 더 깊은 Network를 가능하게 하며 훈련이 더 쉬워지게 하고, classification뿐만 아니라 detection, segmentation 등 다른 비전분야에서도 높은 성능을 보인다는 장점을 제시한다.

1. Introduction

 image classification과 backpropagation등은 이미지 분류에 큰 돌파구로 역할을 하며 여러 feature, level 등을 분류하게 되었다. 이 과정에서 layer의 depth에 관련해 더 많은 레이어를 쌓을수록 네트워크의 훈련이 쉬워질까라는 질문이 떠오르게 되었는데, 기울기의 vanishing/exploding 문제를 야기해 처음부터 convergence를 방해할 수 있다는 점을 단점으로 제시한다. 하지만 이는 normalized initialization이나 중간의 정규화 레이어로 바로잡을 수 있다.

깊은 네트워크가 수렴해가는 과정에서 깊을수록 이점이 있는 게 아니라 degradation이 생긴다. 이는 Overfitting만이 문제가 아니며 깊은 모델에 더 많은 레이어를 추가하는 것이 높은 training error를 보일 수 있다는 것이다. 

 degradation때문에 모든 시스템이 최적화되기는 어려우므로 identity mapping(입력값과 출력값이 동일한 매핑)을 제안하는데, 단순히 identity mapping을 쌓아 layer을 깊게 만든 deeper model은 피상적인 모델보다는 낮은 training error를 가져야 한다. 

 

저자는 이 문제를 H를 직접적으로 매핑하는 것이 아닌 F를 대신 학습하게끔 하는 deep residual learning framework를 도입해 해결하고자 한다. residual mapping을 최적화하는 것이 기존의 매핑보다 더 쉽다고 가정하며 극단적으로는 residual을 0으로 만드는 게(즉 H가 x가 되는 것) identity mapping을 비선형레이어들에 맞추는 것보다 쉽다고 말한다.

 

F(x)+x를 'shortcut connections'라고 칭하는데 이는 identity mapping을 수행하며 단순히 출력값에 x를 더해주는 것이기 때문에 추가적인 파라미터나 계산복잡도가 필요하지 않다. 

 

2. Related Work

저자는 residual representation은 이미 VLAD와 Fisher Vector 등에서 사용된 것처럼 이미 기존의 vector를 인코딩하는 것보다 효율적이라고 밝혀졌고 널리 사용되고 있는 방법이라고 말한다. 또한 Multigrid나 multigrid의 대안인 hierarchical basis preconditioning은 두 개의 scale(coarser, finer scale) 사이에 residual vector를 표현하는 변수에 의존하며 이같은 방법이 residual nature을 모르는 기존의 solver보다 더 빠르게 converge한다.

Shortcut Connection에 관련한 highway networks라는 논문을 제시하는데, 이 논문도 ResNet과 마찬가지로 깊은 네트워크를 학습하기 위한 논문이지만 본논문에서는 항상 residual function을 학습하며 identity shortcut을 사용한다는 점을 다른 점으로 언급한다.

3. Deep Residual Learning

3.1. Residual Learning

 

 H(x)를 stacked layer에 의해 피팅되는 underlying mapping이라고 하고, multiple nonlinear layer이 어떤 복잡한 함수에 근접하게 된다고 가정한다면 이것이 residual function F(x) = H(x) - x에 점근적으로 근접하게 된다고 가정할 수 있다. 따라서 우리는 H(x) 대신 F(x)를 예측하면 되는데 둘다 desired function에 가정에 의한다면 점근적으로 근접하지만 훈련의 용이함이 다르다.

 introduction에서 언급했듯 identity mapping으로 layer을 더 쌓는다면 최소한 더 깊은 모델이 얕은 모델보다 크진 않은 training error을 가질 것이라고 추측할 수 있는데, 쌓여있는 layer에 identity mapping을 수행하는 것은 어려울 수 있으므로 residual learning에 따르면 identity mapping을 기본적으로 수행해 레이어의 weight를 0으로 향하게끔하여 학습난이도를 쉽게하고자 한다.

 실제로는 최적의 solution이 identity mapping일 경우는 드물다. 하지만 optimal function이 zero mapping보다 identity mapping에 가깝다면 새롭게 function을 학습하기 보다 identity mapping에 관련한 섭동(perturbation)을 찾기 쉬울 것이다.

 

3.2. Identity Mapping by Shortcuts

위의 그림과 같이 y = F(x,{Wi}) + x로 building block을 정의할 수 있고 이런 shortcut connection은 추가적인 parameter이나 computation 복잡도 없이 수행할 수 있으며 이는 학습뿐만 아니라 plain과 residual network사이에 비교에서도 중요하다. 이 때 x와 F의 차원이 같아야하는데 input/output의 채널이 변해 차원이 같지 않다면 차원을 같게 하기 위해 선형 투영 Ws를 수행한다. Ws를 차원이 같을 때도 사용할 수 있지만 identity mapping만으로도 degradation문제를 다루기는 충분하기 때문에 차원을 맞출 때만 사용한다.

 residual function F의 형태는 유동적인데 다만 F가 single layer이 되는 경우 linear layer과 비슷하기 때문에 이점을 찾을 수 없다. 따라서 F가 실제로 여러 weight값들이 중첩될 때 유의미한 성능을 보이게 된다.

 

3.3. Network Architectures

 

비교목적으로 기본적인 CNN 네트워크를 가져와 실험을 진행한다. 기본적인 네트워크는 VGG network에서 제안된 기법들에서 주로 영감을 받아 3*3 필터를 이용하며, output feature map size를 위해 필터 개수와 같은 개수의 레이어를 사용하며 feature map size가 반이 되는 경우 레이어당 시간복잡도를 보존하기 위해 필터의 개수는 2배가 되도록 한다. 별도의 pooling 없이 stride값이 2인 convolutional layer로 다운샘플링을 진행하며 네트워크의 마지막에 global average pooling layer을 사용해 1000개의 class로 분류할 수 있도록 네트워크를 설계한다. 결과적으로 본모델이 일반적인 VGG네트워크보다 더 적은 필터를 사용하면서도 복잡도 또한 낮게 나온다는 장점을 제시한다.

점선으로 표시된 경우는 차원이 일치하지 않아 projection shortcut을 수행한 경우이다.

차원이 일치하지 않는 경우 차원을 증가시키기 위해 패딩을 한 후 identity mapping을 거치거나 1*1 convolution을 통해 차원이 일치하도록 projection shortcut을 거칠 수 있다.

 

3.4. Implementation

 

imagenet을 이용해 실제 구현한 경우 이미지에서 224*224 크기로 랜덤샘플화된 크롭이나 수평 flip된 것들을 사용해 매 convolution이후 batch normalization을 사용하였다. 또한 learning rate를 0.1부터 시작해 점차적으로 줄여가는 방법을 사용하고 weight decay와 momentum 등 hyper parameter을 지정해 실험을 수행한다. 

4. Experiments

4.1. ImageNet Classification

 

 

 

4.2. CIFAR-10 and Analysis

4.3. Object Detection on PASCAL and MS COCO

 

 

Comments