Data Efficient Image Transformer
등장배경
ViT가 당시 SOTA를 달성했지만, 좋은 성능을 내려면 매우 큰 데이터셋인 JFT-300M 사용하여 학습시켜야만 했다.
따라서 ViT 모델을 그대로 사용하지만 Knowledge Distillation 개념을 추가한 DeiT 모델이 등장하게 된다. 따라서 DeiT는 오로지 ImageNet 데이터셋만으로 좋은 성능이 나온다!
•
DeiT는 Hard Distillation을 사용하여 적은 데이터 만으로 빠르게 성능을 수렴시켰다.
•
Distillation 토큰을 추가 → 출력에서 Distillation Embedding 이 나오게 되는데,
CLS 토큰에서 나온 Class Embedding과 Distillation Embedding을 융합하여 class를 예측한다.
•
Knowledge Distillation을 할 때, Teacher 모델을 Transformer로 하는 것보다 CNN 기반 모델을 사용하는 것이 더 좋은 성능을 냈다.(CNN은 Transformer보다 inductive bias가 높기 때문..)
•
ImageNet만을 학습한 DeiT 모델은 CIFAR-10, CIFAR-100, Oxford-102 flowers와 같은 Downstream Task들에 대해서도 경쟁력있는 성능을 가지므로 Generalization이 잘 이루어졌다.
Soft Distillation
Hard Distillation
Distillation Token
아래는 Student 모델인 DeiT의 아키텍처이다.
DeiT에서는 ViT에서의 CLS 토큰과 똑같은 방식으로 distillation 토큰을 추가적으로 붙였다.
•
CLS token과 distillation token은 Transformer 내에서 Attention을 통해 상호작용된다.
•
class token이 들어가 마지막 layer에서 classification을 위한 embedding dim이 나온다.
•
distillation token이 들어가 마지막 layer에서 Teacher model의 output과의 CE를 통해 학습을 하도록 하는 embedding dim이 나온다.
•
학습할때는 위와 같고, 추론할 때는 두 토큰으로부터 나온 embedding을 add하여 predict한다.
실험결과
아래 그래프는 Accuracy와 speed의 Trade-off 그래프이다.
•
밑에 파란점의 ViT 모델은 ImageNet 데이터만으로 학습했을 때의 ViT 모델이다.
이런 경우에는 오히려 EfficientNet보다도 떨어지는 성능을 보인다.
증류기 기호는 Distillation token을 사용한 모델임을 의미한다.
Teacher 모델을 DeiT-B 자신으로 사용한 경우와 CNN 기반의 모델(RegNet)을 사용했을 때를 비교해보았더니, CNN 기반의 모델을 Teacher로 했을 경우 더 높은 accuracy를 보이고 있다.
CNN 기반 모델은 inductive bias(translation equivariance, locality)가 있으므로, 굳이 많은 데이터셋이 필요 없게 된다.