굴러가는 분석가의 일상

[개념정리] Diffusion Model Loss Function 본문

Generative AI

[개념정리] Diffusion Model Loss Function

G3LU 2024. 6. 15. 23:17

✔️손실함수(Loss Function) 

역방향 과정의 손실함수는 Negative Log-Likelihood (\(-\log \left(p_{\theta }\left(x_{0}\right)\right)\)) 으로 정의할 수 있다. 하지만 이를 직접 최적화하려면, 각 시간 단계 t에서의 모든 상태 \(x_{t}\)를 추적하고 계산해야 하기 때문에 메모리와 계산 자원 측면에서 매우 비효율적이다. 이러한 문제점을 해결하기 위해 변분 추론(Variational Lower Bound)을 사용한다.

Variational Lower Bound

 

직접적으로 Negative Log-Likelihood 사용하는 대신, 최적화할 수 있는 비음수 값(KL Divergence)을 더하게 된다. 그럼 아래 수식의 오른쪽 항은 항상 왼쪽 항보다 크거나 같게 된다. 오른쪽 항을 최소화함으로써 negative log-likelihood를 최소화할 수 있게 된다. 그럼 최종적으로 계산이 불가능한 손실 함수에 대해 추론이 가능해진다.

 

위의 Variational Lower Bound를 확장하면 세 가지의 항목으로 나뉘게 된다. 

 

1. \(L_{T}\): Constant Term 

 

q는 학습 가능한 파라미터가 없고 p는 단순한 가우시안 노이즈 확률이기 때문에 이 항은 학습 중에 상수로 간주되어 고려하지않아도 됩니다.

 

2. \(L_{T-1}\): Stepwise Denoising Term  

 

현재 상태( \(x_{t}\) )가 주어질 때, 이전 상태 \(x_{t-1}\) 가 나올 확률 분포 q와 p의 KL Divergence를 최소화하여, p와 q 확률분포 차이를 줄일 수 있습니다. 확률 분포 q와 p는 이와 같은 수식으로  \(q\left(x_{t-1}|x_{t},x_{0}\right)\)과 \(p_{\theta }\left(x_{t-1}|x_{t}\right)\)표현될 수 있습니다.

 

두 개의 확률분포는 가우시안 정규분포를 따르며, 이를 밀도 함수(PDF)를 전개하면 아래와 같이 작성할 수 있게 됩니다. 

앞전에 Diffusion Model은 마르코프 특성이 기반이 된다고 하였습니다. 이에 현재 상태 \(x_{t}\), 이전 상태 \(x_{t-1}\) 및 초기 상태 \(x_{0}\) 간의 관계를 나타낼 수 있는 복합적인 수식이 필요합니다. 이러한 수식은 모델의 각 시점에서의 상태를 조금 더 명확하게 예측할 수 있습니다. 또한, 평균과 분산을 구하기 위해서는 왼쪽/오른쪽 항을 가우시안 분포로 rearrange 해야하며, 결과적으로는 아래와 같습니다. 

NOTE : 위 수식에서 오른쪽 항에 대해 자세하게 아는 것이 이를 이해하는데 큰 도움이 된다고 생각하여 작성하게 되었으니, 참고하시는 것이 좋을겁니다!!! 

첫번째 항 : 현재 상태 \(x_{t}\)와  이전 상태 \(x_{t-1}\) 간의 관계를 표현하는 수식입니다. 이는 가우시안 정규분포를 따르며, 이에 대한 수식은 \(N(x_{t};\sqrt{a_{t}}x_{t-1},\beta _{t}I\) 입니다. 이에 대해서 \(x_{t-1}\)의 실제 평균을 추정할 수 있습니다. 실제 평균 \(\mu \left(x_{t},x_{0}\right)=\frac{1}{\sqrt{a_{t}}}\left(x_{t}-\frac{1-a_{t}}{\sqrt{1-a_{t}}}\epsilon \right)\)으로 정의할 수 있습니다. 
 
두번째 항 : 이전 상태 \(x_{t-1}\)와  초기 상태 \(x_{0}\) 간의 관계를 표현하는 수식입니다. 이 또한 가우시안 정규분포를 따르며, 이에 대한 수식은 \(N\left(x_{t-1};\sqrt{a_{t-1}}x_{0},\left(1-\alpha _{t-1}\right)I\right)\) 입니다. 이를 통해 모델이 예측한 평균을 추정할 수 있습니다. 모델이 예측한 평균 \(\mu _{\theta }\left(x_{t},t\right)=\frac{1}{\sqrt{a_{t}}}\left(x_{t}-\frac{1-a_{t}}{\sqrt{1-a_{t}}}\epsilon _{\theta }(x_{t,t)}\right)\)으로 정의할 수 있습니다. 

세번째 항 : 현재 상태 \(x_{t}\)와  초기 상태 \(x_{0}\) 간의 관계를 표현하는 수식입니다. 이는 이전 상태 \(x_{t-1}\)과 독립적이기 때문에 고려하지 않아도 됩니다. 

 

위의 수식들은 가우시안 분포를 따르고 있으며, 유일한 매개변수는 평균이기 때문에 KL Divergence을 사용하면 아래와 같이 Loss를 도출할 수 있습니다. 

 

Loss는 단순하게 실제 평균 값과 모델이 예측한 값의 MSE(Mean Squared Error)이며, 가중치 항을 무시하는 것이 더 좋은 결과를 얻을 수 있다고 합니다.

 

3. \(L_{0}\): Constant Term 

 

다음은 마지막 잡음 제거 단계의 재구성 손실에 관한 내용이며, 아래와 같은 이유로 훈련 시 크게 신경 쓰지 않아도 된다. 

  • Lₜ₋₁에서 동일한 신경망을 사용하여 근사할 수 있다.
  • 이를 무시하면 샘플의 품질이 좋아지고 구현이 더 간단해진다.

 

'Generative AI' 카테고리의 다른 글

[개념 정리] Diffusion Model  (0) 2024.06.01