About me
home
Portfolio
home
🥍

Reverse Distillation 정리

기존 knowledge distillation 기반의 이상 탐지 모델은 이상 데이터에 대해 teacher와 student 모델의 출력 특성 간의 차이를 최대화하는 것이 중요 과제이다. 그러나 기존 knowledge distillation 구조에는 두 가지 특징이 있다. 첫 번째로 teacher와 student 모델의 네트워크 구조가 유사하며, 두 번째로 input이 teacher와 student에 동일하게 입력되어 데이터 흐름이 동일하다는 점이다. 이러한 구조로 인해 이상 데이터가 들어왔을 때에도 두 모델이 특성을 유사하게 추출하게 되어 이상 탐지를 원활히 하지 못하는 문제점이 생긴다.
이러한 한계점에 대한 해결책으로 역순(reverse)된 구조를 적용하여 새로운 방식을 제시한 Reverse Distillation 모델이 등장하였다.
이 모델은 STPM 모델의 주요 특징을 유지한 채 teacher 모델을 encoder 구조로, student 모델을 decoder 구조로 변형하였다. 또한 input 이미지의 data flow를 역전시킨 아키텍처를 제시한다.

아키텍처 구조

각각 3개의 층으로 구성된 teacher encoder와 student decoder를 확인할 수 있다. 각 layer의 encoder와 decoder의 특성은 손실함수의 입력으로 들어가게 된다. bottleneck 단계에서는 Multi-Feature Fusion Block(MFF)One-Class Embedding (OCE)블록이 순차적으로 배치되어 구성된다.

Teacher encoder

ImageNet 데이터셋으로 사전 학습된 WideResNet50 모델을 사용하며, 3개의 레이어로 구성된 Encoder 구조를 가진다. 이 구조에서 데이터 흐름에 따라 down-scaling이 된다.

Student decoder

teacher encoder와는 다른 구조를 가지며, 데이터 흐름이 하위 층부터 역순으로 진행된다. Teacher 모델과 유사한 구조를 사용하지만, 사전 학습이 되지 않은 상태로 MVTec AD 데이터셋의 정상 데이터만을 사용하여 학습한다. 이는 정상 특성만을 학습하고자 하는 목적으로 설정되었다. 이러한 방식으로 인해 추론 단계에서 이상 데이터가 입력으로 주어지면, student 모델은 이상 데이터에 대한 특성을 정상적으로 추출해내지 못하게 된다.

MFF(Multi-Feature Fusion Block)

teacher 모델의 최종 출력을 전처리하여 student 모델의 입력으로 사용한다. Teacher encoder의 상위 층에서는 global semantic, structural information과 같은 high-level 특성을 추출하며, 하위 층에서는 색상, 가장자리, 질감과 같은 low-level 특성을 추출한다. Teacher encoder에서 생성된 low-level과 high-level 특성을 student decoder에 모두 전달하기 위해, 각 층에서의 출력은 합성곱 연산을 통해 크기가 조정된 후 연결(concatenate)된다. 이를 통해 지식 전달 과정이 효율적으로 이루어진다.

OCE(One-Class Embedding)

MFF 블록의 출력이 고차원 공간에 분포하는 문제를 해결하기 위해 고차원 특성을 저차원 공간으로 투영한다. 이 과정을 통해 정상 데이터의 표현 특성이 낮은 차원의 정보로 압축되며, 정상 데이터의 특성만이 집약된 지식을 student 모델에 전달한다. 이 방법은 이상 데이터의 특성이 student 모델로 전달되는 것을 방지하여 모델 성능을 향상시키는 효과를 제공한다.

Loss function

코사인 유사도(Cosine similarity)를 손실 함수로 사용한다. 코사인 유사도는 두 벡터 간의 각도를 기반으로 유사성을 측정하는 방법으로, -1에서 1 사이의 값을 가진다. 각 층에서 추출된 teacher와 student 모델 간의 특성 맵(feature map)을 vector-wise로 손실 함수에 입력하므로, 픽셀 단위의 유사도가 계산된다. 유사도가 높을 경우 손실 값은 0에 가까워지고, 유사도가 낮을 경우 손실 값은 1에 가까워진다. 학습 과정에서는 역전파를 통해 student 모델의 가중치가 teacher 모델이 정상 데이터에 대해 추출한 특성과 유사하도록 학습된다.