SimCLR v1
이미지 데이터의 정답 label이 없는 상황에서 효과적으로 visual representation을 추출하는 SimCLR이라는 이름의 unsupervised learning algorithm을 소개합니다. SimCLR은 data augmentation을 통해 얻은 postive/ negative sample들에 대해 contrastive learning을 적용시켰으며, 성능 측면에서 supervised learning으로 학습한 모델들에 준하는 모습을 보여줍니다.
Contrastive Learning Framework
Unsupervised Learning이란 데이터의 label 없이 네트워크 모델을 학습하는 것을 의미합니다. 이전에 Computer vision 분야에서는 이미지를 임의로 회전시킨 후 모델이 회전 방향을 맞추게끔 학습시키거나, 이미지를 잘라 zigsaw 퍼즐을 만든 후 모델이 퍼즐을 풀 수 있게끔 모델을 학습했습니다. 이렇게 모델을 학습하기 위해 정의한 새로운 형태의 문제를 pretext task라고 부릅니다.
pretext task를 통해 학습하는 방식은 어느 정도의 성능을 보여주긴 했지만, 해당 pretext task를 잘 풀게끔 학습되었을 뿐 이미지의 일반적인 시각적 특징을 잡아내지는 못합니다.
이를 해결하기 위해 최근에는 contrastive learning 기반의 방식들이 많이 연구되고 있습니다. Contrastive learning이란 positive pair끼리는 같게, negative pair끼리는 다르게 구분하면서 모델을 학습하는 방식입니다. 예를 들면 노랑이라는 키워드(query)가 주어지고 사과/바나나/딸기라는 보기(key)가 있을 때, 노랑-바나나를 연결하고 사과/딸기와는 연결되지 않게 학습하는 방법입니다. 위 방식은 이전에 발표된 여러 연구들(CPC, CMC 등)에서 뛰어난 성능을 보여주었습니다.
SimCLR learns representations by maximizing agreement between differently augmented views of the same data example via s contrastive loss in the latent space.
하나의 sample에 대해 augmentation t, t’을 각각 적용해 얻은 representation 간의 agreement를 최대화 시키는 방향으로 모델이 학습된다고 합니다.
위의 그림에서 보면, 하나의 이미지(x)가 서로 다른 두 개의 augmentation 변환을 거쳐 두 개의 이미지(xi, xj)로 나눠집니다. 이렇게 변환된 두 이미지는 같은 이미지로부터 얻었기 때문에 positive pair로 정의합니다. 만약 또다른 이미지인 y로부터 yi, yj의 변환된 이미지가 나왔다고 한다면 xi과 yi(또는 yj)는 서로 다른 이미지로부터 얻었기 때문에 negative pair로 정의합니다.
변환된 각 이미지들 (xi ,xj)는 CNN 기반의 네트워크(f)를 통과하여 visual representation embedding vector(hi, hj)로 변환됩니다. 이러한 representation vector를 생성하는 network를 base encoder라고 부르며 논문에서는 ResNet을 base Encoder로 이용하였습니다.
Visual representation vector는 MLP 기반의 네트워크(g)를 통과하여 변환되고, 변환된 output(zi, zj)를 이용하여 contrastive loss를 계산합니다. MLP 기반의 네트워크는 projection head라고 부르며, 두 개의 linear layer 사이에 ReLU activation function을 넣은 구조로 구성되어 있습니다.
Encoder 및 projection head는 batch 단위로 학습하게 되는데, 만약 N의 batch size를 이용하게 된다면 각각 data augmentation을 거쳐서 2N개의 sample을 얻을 수 있습니다. 이렇게 되면 각 sample 별로 1쌍의 positive pair와 2N-2쌍의 negative pair를 구성할 수 있게 됩니다. 논문에서는 positive pair 간의 similarity는 높이고, negative pair 간의 similarity는 최소화하는 형태의 loss function을 제안하여 학습에 활용하였습니다. 해당 loss function은 NT-Xent라는 이름으로 불리며, 아래와 같은 방식으로 계산됩니다.
학습은 batch 단위로 진행되기 때문에, 많은 양의 negative pair를 구성하기 위해서는 큰 batch size를 이용해서 학습해야 합니다. 이를 위해 SimCLR은 기본적으로 4096의 batch size(총 8192개의 sample)를 이용하여 학습했으며 빠른 학습을 위해 128코어의 google cloud TPU를 사용했다고 합니다. 또한 SGD나 Momentum optimizer가 아닌, 큰 크기의 batch size로 학습할 때 적절하다고 알려진 LARS optimizer를 이용하여 multi-device(분산학습)으로 학습하였습니다.
Batch normalization을 적용할 때는, device 별로 평균과 표준 편차를 계산하여 적용하는 것이 아니라, 모든 device에서의 평균/표준편차 값들을 통합해서 적용하였습니다. 이렇게 하면 positive sample이 포함된 device와 negative sample만으로 구성된 device 간의 분포를 같게 normalize 하게 되어 batch normalization 과정에서 발생하는 정보 손실을 최소화할 수 있습니다.
논문에서는 위에서 제안한 contrastive learning 기반의 framework로 다양한 실험을 진행했습니다. 기본적인 unsupervised learning 과정은 모두 ImageNet ILSVRC-2012 데이터셋으로 진행하였고, 학습한 encoder를 고정(freeze)시키고 그 위에 linear classifier를 얹어서 정확도를 측정하는 linear evaluation 방식으로 모델을 평가하였습니다. 그 외에 encoder를 고정시키지 않고 학습 가능하게 만들어서 평가하는 fine-tuning 방식이나, 다른 dataset을 이용해서 모델 변수를 조정하는 transfer learning 방식으로 SimCLR encoder를 평가하였습니다.
N 개의 데이터를 random sample 하여 mini-batch를 이루고 i 번째 sample에 대해 2(N - 1) 개의 augmented sample로 negative sample을 구성합니다. 별도로 negative sample을 만들어주지 않습니다.
$ sim(u,v) = u^{T}v / \left|u\right|\left|v\right| $ (이 때 u,v는 l2 normalized vector)는 cosine similarity를 의미합니다.)
위의 식에서 로그 안의 항을 분리해주면 , 분모의 식은 positive sample 간의 유사도 ⋯(1), 분모의 식은 나머지 negative sample 과의 유사도의 합 ⋯(2) 를 의미합니다.
즉 loss function은 , positive sample 간의 유사도는 크게, negative sample 간의 유사도는 작게 해주는 효과가 생깁니다.
Data Augmentation for Contrastive Representation Learning
SimCLR에서는 위의 그림에서 보이듯 cropping이나 resizing, rotating, cutout 등 이미지의 공간적/기하학적 구조를 변형하는 data augmentation 방법과 color dropping, jittering, Gaussian blurring, Sobel filtering 등 이미지의 색상을 왜곡하는 data augmentation 방법들을 제시하였습니다.
사실 ImageNet 데이터셋의 이미지들은 서로 다른 크기를 가지고 있기 때문에 학습 전에 항상 crop/resize 과정을 거쳐서 변환해주었다고 하는데요. SimCLR에서는 crop/resize 과정을 기본으로 하고, 한쪽 augmentation branch에서는 테스트해보고자 하는 다른 augmentation 방법들을 추가해주고 다른 한쪽 branch는 그대로 둔 채 학습을 진행하여 성능을 비교했습니다. 이러한 비대칭적인 구성은 다른 branch에도 augmentation 과정을 추가했을 때보다 성능이 낮을 수 있는데요, 그럼에도 불구하고 공정한 비교를 위해 이러한 방식을 선택했다고 합니다.
총 7가지의 data augmentation 방법을 하나 또는 두개 이어붙여서 성능을 측정하였는데요, 결과적으로 하나의 augmentation만으로는 좋은 성능을 달성하기 어려웠고, 여러 augmentation을 더해주었을 때 predictive task의 난이도가 높아지면서 representation quality가 증가했다고 합니다. 두 가지 augmentation을 이어 붙인 경우에는 위의 그림에서 알 수 있듯 random crop과 random color distortion을 이어붙인 경우에 가장 좋은 성능을 보여주었습니다.
특히 논문에서는 Color distortion이 꼭 필요한 이유에 대해서도 나름의 분석을 보여주었는데요. Color distortion 없이 random crop만 진행한 경우에는 augmentation branch를 통해서 얻은 sample들이 위의 historgram에서 보이듯 서로 같은 color distribution을 공유하고 있었고, 결국 네트워크가 시각적인 특징을 찾아내는 것이 아닌 색 배합만을 찾아내어 낮은 representation quality를 보여주었습니다.
또한 data augmentation의 세기를 바꾸어가며 모델의 성능을 측정해보기도 하였는데요, 위의 표를 보시면 color distortion을 강하게 가할수록 contrastive prediction task의 난이도가 증가하여 visual representation을 더 잘 추출하게끔 학습하였습니다. 심지어는 supervised learning에 도움이 되지 않는 강도의 augmentation도 SimCLR에서는 성능 향상에 기여하는 것을 볼 수 있습니다.
Model Architecture
위의 그림은 SimCLR과 supervised learning의 학습 방법을 다양한 크기의 모델에 적용시키며 linear evaluation 성능을 통해 비교한 것입니다. Supervised learning과 마찬가지로 SimCLR도 모델의 크기가 커질수록 학습 성능이 증가하는 경향을 보여주었습니다.
또한 non-linear projection head를 통해서 contrastive loss를 계산하는 구조 역시 linear projection head나 projection head를 아예 이용하지 않을 때보다 항상 좋은 성능을 낸다는 것을 보여주었습니다. 이 때 projection head의 output dimension은 성능에 크게 영향을 주지 않는 것이 확인되었습니다.
Loss Function and Batch Size
SimCLR에서는 cross-entropy 기반의 NT-Xent loss function을 이용하여 contrastive learning을 진행합니다. 논문에서는 NT-Xent loss와 기존 Contrastive learning에서 많이 사용되는 NT-Logistic, Margin triplet loss를 비교하면서 loss function 선정의 정당성을 보여주었습니다.
NT-Xent loss는 cross entropy loss를 기반으로 하기 때문에 negative sample들이 기준 sample과 얼마나 다른지에 대한 크기를 반영하고 있고, 결과적으로 좋은 성능을 보여준다고 합니다.
Contrastive learning은 안정적인 학습을 위해 충분한 양의 negative sample이 필수적입니다. Negative sample의 개수는 batch size와 비례하기 때문에, SimCLR을 학습할 때 batch size를 키울수록 모델의 성능이 증가하는 경향을 보여주었습니다. 또한 학습 과정에 random augmentation이 포함되어 있기 때문에 학습 시간이 길어질수록 충분한 양의 negative sample을 볼 수 있고, 성능에 대한 유의미한 경향성을 찾을 수 있었습니다.
Test
SimCLR은 1. 학습된 모델을 고정(freeze)하고 위에 linear classifier를 얹어서 성능을 평가하는 linear evaluation, 2.학습된 모델과 linear classifier를 모두 learnable한 상태로 학습하는 fine-tuning, 3.학습된 모델을 다른 종류의 dataset에 대하여 learnable한 상태로 학습하는 transfer learning 의 세 가지 방법으로 평가하였습니다.
우선 기존의 self-supervised 방법들과 비교했을 때 SOTA의 성능을 보여주었습니다. 위의 두 그림은 Linear evaluation과 적은 dataset에 대한 fine-tuning 평가의 결과입니다. 두 방법들 모두 좋은 성능을 보여주었고, fine-tuning의 경우에는 같은 모델의 supervised learning 학습 결과보다도 좋은 성능을 보여주었습니다. (단, SimCLR의 경우 이미 pre-training된 모델을 fine-tuning 한 것이고 supervised learning의 경우 scratch부터 학습했기 때문에 학습량 차원에서는 공정한 비교가 아닙니다.)
정리
Self-Supervised Pre-Training of Swin Transformers for 3D Medical Image Analysis 논문을 읽어보려다가 self-supervised learning에 대해서 공부를 해야되겠다는 생각이 들어서 본 논문을 읽게되었습니다. SimCLR v2, MOCO v1 v2와 같은 논문도 빠르게 훑고 넘어가려고 합니다. 또한 kaggle에서 관련 코드들을 연습하고 포스팅하도록 하겠습니다.