About me
home
Portfolio
home
🫖

Knowledge Distillation 정리

정의: 큰 네트워크(Teacher 모델)로부터 작은 하나의 네트워크(Student 모델)에 지식을 전달하는 기법. 즉 복잡한 모델이 학습한 generalization 능력을 단순한 모델에게 transfer 해주는 것이다.
* Transfer learning와는 달리 같은 도메인 내의 지식을 전달하는 방식이다.
의의: 무거운 모델의 성능을 작은 모델이 낼 수 있다면 Computing resource(GPU, CPU) ,배터리, 메모리 측면에서 더 효율적이다. 모델을 실제 서비스로 배포할때도 필요성을 찾을 수 있다. 딥러닝 모델이 배포되는 모바일 장치는 복잡한 모델을 작동시킬만한 강력한 하드웨어가 아니다.

Loss function은 student lossdistillation loss가 있다.

α 는 두 loss에 대한 가중치이다. 무슨 Loss를 더 중요하게 학습할 건지 판단하는 파라미터이다.
TT는 Temperature로 극단적으로 Softmax 값이 차이나게 되는 것을 막아주는 파라미터이다.
Hard label을 하게 되면 일부 class에 대한 확률값은 0에 가까운데, 이는 학습 시에 지식이 잘 전달되지 않게 된다. → T를 통해 Soft label으로 만들어 준다.
T가 높으면 → 비슷비슷하게, T가 작으면 → 하나만 뚜렷하게 된다. (T는 보통 2~5 사이의 값을 사용한다고 한다.)
또한 T는 output을 더 soft하게 만들어주기 때문에 overfitting을 방지하는 효과도 있다. 모델의 parameter가 많을수록 overfitting될 가능성이 높은데 T가 그걸 막아준다.

두 가지 Loss 설명

전체 학습 과정

1.
Teacher Network 학습
training set(input, hard label)을 이용해 large model을 학습한다. (사전에 미리 해놓는다.)
이후에 T를 이용하여 soft labels를 생성한다.
2.
Student Network 학습
input과 앞서 구한 soft labeltransfer set을 생성한다.
transfer set은 두가지 prediction에 쓰이고 small model을 학습한다.
아까와 똑같은 T를 사용하여 soft predictions 도출
이번엔 T 없이 hard predictions 도출
3.
distillation loss 구성
soft labelssoft predictions로 KL Divergence를 통해 구한다.
4.
student loss 구성
hard prediction과 hard label로 Cross-Entropy를 통해 구한다.
5.
두 loss를 더한 최종 loss를 통해 모델 갱신

Distillation 방식 종류