NeurIPS 2024 (Oral). [Paper] [Page] [Github] [Huggingface]
Tianwei Yin, Michaël Gharbi, Taesung Park, Richard Zhang, Eli Shechtman, Fredo Durand, William T. Freeman
MIT | Adobe Research
23 May 2024

Introduction

Diffusion model의 샘플링 절차에는 일반적으로 수십 개의 반복적인 denoising step이 필요하며, 각각은 신경망을 통한 forward pass이다. 이로 인해 고해상도 text-to-image 합성이 느리고 비용이 많이 든다. 이 문제를 해결하기 위해 teacher diffusion model을 효율적인 few-step student model로 변환하기 위한 수많은 distillation 방법이 개발되었다. 그러나 student model은 일반적으로 teacher model의 noise-이미지 매핑을 학습하기 위해 학습되지만 teacher model의 동작을 완벽하게 모방하는 데 어려움을 겪기 때문에 종종 품질이 저하된다.

그럼에도 불구하고 GAN loss나 DMD loss와 같이 분포를 일치시키는 것을 목표로 하는 loss function은 noise에서 이미지로 가는 특정 경로를 정확하게 학습하는 데 따른 복잡도에 시달리지 않는다. 그 이유는 그 목표가 분포 면에서 teacher model과 일치시키는 것이기 때문이다. 즉, student와 teacher 출력 분포 간의 JS divergence 또는 대략적인 KL divergence를 최소화하는 것이다.

특히, DMD는 Stable Diffusion 1.5를 distillation하는 데 있어 SOTA 결과를 보여주었지만, GAN loss보다 조사가 덜 이루어졌다. 그럴 만한 이유 중 하나는 DMD가 안정적인 학습을 보장하기 위해 여전히 추가적인 regression loss가 필요하기 때문이다. Regression loss를 사용하려면 teacher model의 전체 샘플링 step을 실행하여 수백만 개의 noise-이미지 쌍을 만들어야 하는데, 이는 비용이 많이 든다. 또한 regression loss는 DMD의 분포 매칭 loss의 주요 이점을 무효화하는데, 이는 student의 품질이 teacher의 품질보다 좋아질 수 없기 때문이다.

본 논문에서는 학습 안정성을 손상시키지 않고 DMD의 regression loss를 없애는 방법을 보여준다. 그런 다음 GAN 프레임워크를 DMD에 통합하여 분포 매칭의 한계를 넓히고, ‘backward simulation’이라는 새로운 학습 절차를 사용하여 few-step 샘플링을 가능하게 한다. 이러한 모든 것을 종합하면, 4-step 샘플링을 사용하여 teacher보다 성능이 뛰어난 SOTA 생성 모델이 탄생한다. DMD2라고 하는 본 논문의 방법은 one-step 이미지 생성에서 SOTA 결과를 달성하였으며, SDXL를 distillation하여 고품질 메가픽셀 이미지를 생성함으로써 확장성(scalability)을 보여주었다.

Background: Distribution Matching Distillation

Distribution Matching Distillation (DMD)은 타겟 분포 \(p_{\textrm{real},t}\)와 generator 출력 분포 \(p_{\textrm{fake},t}\) 사이의 대략적인 KL divergence를 최소화하여 여러 step의 diffusion model을 one-step generator $G$로 distillation한다. DMD는 gradient descent로 $G$를 학습시키므로 이 loss의 gradient만 필요하며, 이는 2개의 score function의 차이로 계산할 수 있다.

\[\begin{equation} \nabla \mathcal{L}_\textrm{DMD} = \mathbb{E}_t (\nabla_\theta \textrm{KL}(p_{\textrm{fake},t} \| p_{\textrm{real},t})) = - \mathbb{E}_t \left( \int (s_\textrm{real} (F(G_\theta (z), t), t)) \frac{d G_\theta (z)}{d \theta} dz \right) \\ \textrm{where} \quad s (x_t, t) = \nabla_{x_t} \log p_t (x_t) = - \frac{x_t - \alpha_t \mu (x_t, t)}{\sigma_t^2} \end{equation}\]

($z \sim \mathcal{N}(0,I)$는 Gaussian noise, $\theta$는 generator의 파라미터, $F$는 forward diffusion process)

DMD는 teacher \(\mu_\textrm{real}\)로 고정된 사전 학습된 diffusion model을 사용하고, $G$의 샘플에 대한 denoising score-matching loss를 사용하여 student \(\mu_\textrm{fake}\)를 동적으로 업데이트한다.

DMD 논문에서는 \(\nabla \mathcal{L}_\textrm{DMD}\)를 정규화하고 고품질 모델을 위해 추가 regression 항이 필요하다는 것을 발견했다. 이를 위해 noise map $z$에서 시작하여 teacher diffusion model과 deterministic sample를 사용하여 생성된 이미지 $y$로 구성된 noise-이미지 쌍 $(z, y)$의 데이터셋을 수집하였다. 입력 noise $z$가 주어지면 regression loss는 generator 출력을 teacher의 예측과 비교한다.

\[\begin{equation} \mathcal{L}_\textrm{reg} = \mathbb{E}_{(z, y)} d(G_\theta (z), y) \end{equation}\]

($d$는 LPIPS와 같은 distance function)

대규모 text-to-image 모델이나 복잡한 조건이 있는 모델의 경우, 이 데이터를 수집하는 것이 상당히 어렵다. 예를 들어 SDXL에 대한 noise-이미지 쌍 하나를 생성하는 데는 약 5초가 걸리며 LAION 6.0 데이터셋의 1,200만 개 프롬프트를 처리하는 데 약 700~100일이 걸린다. 이 데이터셋 구성 비용만 해도 전체 학습 컴퓨팅 비용의 4배가 넘는다. 이 regression loss는 teacher의 샘플링 경로를 고수하도록 장려하기 때문에 분포에서 student와 teacher를 일치시키는 DMD의 목표와도 상충된다.

Improved Distribution Matching Distillation

1. Removing the regression loss

DMD에서 사용된 regression loss는 mode coverage와 학습 안정성을 보장하지만, 대규모 distillation을 번거롭게 만들고 분포 매칭 아이디어와 상충되어 distillation된 generator의 성능을 teacher model의 성능으로 제한한다. 본 논문의 첫 번째 개선 사항은 이 loss를 제거하는 것이다.

2. Stabilizing pure distribution matching with a Two Time-scale Update Rule

단순하게 DMD에서 regression loss를 생략하면 학습 불안정성이 발생하고 품질이 크게 떨어진다. 이 불안정성은 가짜 diffusion model \(\mu_\textrm{fake}\)의 근사 오차에 기인한다. 이 모델은 generator의 비정상 출력 분포에 동적으로 최적화되기 때문에 fake score \(s_\textrm{fake}\)를 정확하게 추적하지 못한다. 이로 인해 근사 오차와 편향된 gradient가 발생한다.

저자들은 two time-scale update rule (TTUR)을 사용하여 이를 해결하였다. 즉, \(\mu_\textrm{fake}\)와 generator $G$를 서로 다른 주기로 학습시켜 \(\mu_\textrm{fake}\)가 generator의 출력 분포를 정확하게 추적하도록 한다. 구체적으로, regression loss 없이 generator 업데이트당 5번의 \(s_\textrm{fake}\) 업데이트를 사용하면 안정성이 우수하며 ImageNet에서의 원래 DMD 품질과 일치하면서 훨씬 빠른 수렴을 달성할 수 있다.

3. Surpassing the teacher model using a GAN loss and real data

TTUR을 통해 비용이 많이 드는 데이터셋 구축 없이도 DMD와 비슷한 학습 안정성과 성능을 달성했지만, 여전히 distillation된 generator와 teacher diffusion model 사이에는 성능 격차가 남아 있다. 저자들은 이 격차가 DMD에서 사용된 real score function \(\mu_\textrm{real}\)의 근사 오차에 기인할 수 있다고 가정하였다. 이 오차는 generator로 전파되어 최적이 아닌 결과로 이어질 것이다. DMD의 distillation된 모델은 실제 데이터로 학습되지 않기 때문에 이러한 오차를 복구할 수 없다.

저자들은 파이프라인에 추가로 GAN loss를 통합하여 이 문제를 해결하였다. 여기서 discriminator는 실제 이미지와 generator가 생성한 이미지를 구별하도록 학습된다. 실제 데이터를 사용하여 학습된 GAN classifier는 teacher의 한계에 시달리지 않아 student가 teacher의 샘플 품질에서 능가할 가능성이 있다. DMD에 GAN classifier를 통합하기 위해 fake denoiser의 bottleneck 위에 classification branch를 추가한다. UNet의 classification branch와 인코더 feature는 표준 GAN loss를 최대화하여 학습된다.

\[\begin{equation} \mathcal{L}_\textrm{GAN} = \mathbb{E}_{x \sim p_\textrm{real}, t \sim [0,T]} [\log D(F(x, t))] + \mathbb{E}_{z \sim p_\textrm{noise}, t \sim [0,T]} [-\log (D(F(G_\theta (z), t)))] \end{equation}\]

($D$는 discriminator, $F$는 forward diffusion process)

Generator $G$는 이 loss를 최소화한다. 이 GAN loss는 쌍 데이터가 필요하지 않고 teacher의 샘플링 궤적과 독립적이므로 분포 매칭 철학과 더 일관성이 있다.

4. Multi-step generator

제안된 개선 사항들을 통해 ImageNet과 COCO에서는 teacher diffusion model의 성능과 일치시킬 수 있다. 그러나 SDXL과 같은 대규모 모델은 noise에서 디테일한 이미지로의 직접 매핑을 학습하기 위한 복잡한 최적화 환경과 모델 용량으로 인해 one-step generator로 정제하기 어렵다. 이는 DMD를 확장하여 multi-step 샘플링을 지원하도록 동기를 부여했다.

$N$개의 timestep \(\{t_1, \ldots, t_N\}\)으로 구성된 schedule을 학습과 inference 모두에 동일하게 사용한다. Consistency model을 따라, inference하는 동안 각 step에서 noise 제거와 noise 주입 단계를 번갈아 사용하여 샘플 품질을 개선한다. 구체적으로, Gaussian noise $z_0 \sim \mathcal{N}(0,I)$에서 시작하여, 최종 이미지 \(\hat{x}_{t_N}\)을 얻을 때까지 denoising 업데이트와 forward diffusion step을 번갈아 가며 수행한다.

\[\begin{equation} \hat{x}_{t_i} = G_\theta (x_{t_i}, t_i) \\ x_{t_{i+1}} = \alpha_{t_{i+1}} \hat{x}_{t_i} + \sigma_{t_{i+1}} \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \end{equation}\]

1000개의 step으로 학습된 teacher model의 경우, 4-step 모델은 999, 749, 499, 249의 timestep을 사용한다.

5. Multi-step generator simulation to avoid training/inference mismatch


이전 multi-step generator는 일반적으로 noise가 있는 실제 이미지의 noise를 제거하도록 학습된다. 그러나 inference 중에 순수한 noise에서 시작하는 첫 번째 단계를 제외하고 generator의 입력은 이전 generator 샘플링 단계 \(\hat{x}_{t_i}\)에서 나온다. 이로 인해 품질에 부정적인 영향을 미치는 학습-inference 불일치가 발생한다.

저자들은 학습 중에 noisy한 실제 이미지를 현재 student generator가 몇 step 실행하여 생성한 noisy한 합성 이미지 \(x_{t_i}\)로 대체하여 이 문제를 해결하였다. 이는 inference 파이프라인과 유사하다. Teacher diffusion model과 달리 generator는 몇 step만 실행되기 때문에 처리하기 쉽다. 그런 다음 generator는 이러한 시뮬레이션된 이미지의 noise를 제거하고 제안된 loss function으로 학습된다. Noisy한 합성 이미지를 사용함으로써 불일치를 피하고 전반적인 성능을 향상시킬 수 있다.

Experiments

1. Class-conditional Image Generation

다음은 ImageNet-64×64에서의 이미지 생성 품질을 비교한 표이다.

2. Text-to-Image Synthesis

다음은 COCO 2014에서 SDXL backbone으로 이미지 생성 품질을 비교한 표이다.


다음은 user study 결과이다.


다음은 시각적으로 이미지 생성 품질을 비교한 결과이다.

3. Ablation Studies

다음은 (왼쪽) ImageNet과 (오른쪽) COCO 2014 (SDXL)에서의 ablation study 결과이다.


다음은 SDXL에 대한 ablation study 결과이다.

Limitations

  1. Student model은 teacher model에 비해 이미지 다양성이 약간 저하된다.
  2. 가장 큰 SDXL 모델의 품질과 일치하기 위해 여전히 4-step이 필요하다.