ICLR 2022. [Paper] [Github]
Emiel Hoogeboom, Alexey A. Gritsenko, Jasmijn Bastings, Ben Poole, Rianne van den Berg, Tim Salimans
Google Research
5 Oct 2021

Introduction

심층 생성 모델은 이미지, 텍스트, 오디오와 같은 다양한 데이터 소스를 모델링하는 데 큰 발전을 이루었다. Likelihood 기반 모델의 인기 있는 유형은 autoregressive model (ARM)이다. ARM은 확률 연쇄 법칙을 사용하여 조건부 분해로 고차원 결합 분포를 모델한다. 매우 효과적이지만 ARM은 데이터를 생성하기 위해 미리 지정된 순서가 필요하며, 이는 예를 들어 이미지와 같은 일부 데이터 형식에 대한 명확한 선택이 아닐 수 있다. 또한 단일 신경망 호출로 ARM의 likelihood를 검색할 수 있지만 모델에서 샘플링하려면 데이터의 차원과 동일한 수의 네트워크 호출이 필요하다.

최근 diffusion model은 새로운 학습 패러다임을 도입했다. 즉, 데이터 포인트의 전체 likelihood를 최적화하는 대신 likelihood 경계의 구성 요소를 샘플링하여 최적화할 수 있다. Discrete space에서의 diffusion에 대한 연구는 역 생성 과정이 카테고리 분포로 학습되는 이산적인 파괴 과정을 설명한다. 그러나 좋은 성능을 얻으려면 이러한 프로세스의 길이가 길어야 할 수 있으며, 이로 인해 discrete diffusion으로 likelihood를 샘플링하거나 평가하기 위해 많은 수의 네트워크 호출이 발생한다.

본 연구에서는 임의의 순서로 생성하는 방법을 배우는 ARM의 변형인 autoregressive diffusion model (ARDM)을 소개한다. ARDM은 순서에 구애받지 않는 ARM과 discrete diffusion model을 일반화한다. 우리는 ARDM이 몇 가지 이점이 있음을 보여준다. 표준 ARM과 달리 분포 파라미터를 예측하는 데 사용되는 신경망에 구조적 제약을 부과하지 않는다. 또한 ARDM은 동일한 성능을 얻기 위해 흡수 모델보다 훨씬 적은 step이 필요하다. 또한 diffusion model용으로 개발된 동적 프로그래밍 접근 방식을 사용하여 ARDM을 병렬화하여 성능을 크게 저하시키지 않고 여러 토큰을 동시에 생성할 수 있다. 저자들은 경험적으로 ARDM이 discrete diffusion model과 유사하거나 더 나은 성능을 보이는 동시에 모델링 단계에서 더 효율적임을 입증하였다.

Autoregressive Diffusion Models

ARDM은 임의의 순서로 변수를 하나씩 생성한다. 또한 ARDM은 픽셀의 비트 값과 같은 변수를 업스케일링할 수 있다. 표준 ARM과 달리 ARDM은 최신 diffusion model에서와 같이 목적 함수의 단일 step에서 학습된다. 또한 ARDM의 샘플링과 inference는 log-likelihood의 저하를 최소화하면서 동적 프로그래밍을 사용하여 병렬화할 수 있다.

Order Agnostic ARDMs

공학적 관점에서 ARM을 parameterize하는 주요 어려움은 삼각 또는 인과 관계 의존성을 적용해야 한다는 것이다. 특히 2D 신호의 경우 이 삼각 의존성은 임의의 순서에 적용하기 어렵고 멀티스케일 아키텍처에는 지루한 설계가 필요하다. 이 요구 사항을 완화하기 위해 최신 diffusion 기반 생성 모델에서 영감을 얻는다. 이러한 통찰력을 사용하여 한 번에 한 step에만 최적화된 목적 함수를 도출한다.

순서에 구애받지 않는 ARM에 대한 목적 함수는 $t$에 대한 합계를 적절하게 재가중된 기대값으로 대체하여 도출할 수 있다.

\[\begin{aligned} \log p(x) &\ge \mathbb{E}_{\sigma \sim \mathcal{U} (S_D)} \sum_{t=1}^D \log p(x_{\sigma (t)} \vert x_{\sigma (<t)}) \\ &= \mathbb{E}_{\sigma \sim \mathcal{U} (S_D)} D \cdot \mathbb{E}_{t \sim \mathcal{U} (1, \cdots, D)} \log p(x_{\sigma (t)} \vert x_{\sigma (<t)}) \\ &= D \cdot \mathbb{E}_{t \sim \mathcal{U} (1, \cdots, D)} \mathbb{E}_{\sigma \sim \mathcal{U} (S_D)} \frac{1}{D-t+1} \sum_{k \in \sigma (\ge t)} \log p(x_k \vert x_{\sigma (<t)}) \end{aligned}\]

간결하게 하한 기대값을 다음과 같이 쓸 수 있다.

\[\begin{equation} \log p(x) \ge \mathbb{E}_{t \sim \mathcal{U} (1, \cdots, D)} [D \cdot \mathcal{L}_t] \\ \textrm{where} \quad \mathcal{L}_t = \frac{1}{D-t+1} \mathbb{E}_{\sigma \sim \mathcal{U} (S_D)} \sum_{k \in \sigma (\ge t)} \log p(x_k \vert x_{\sigma (<t)}) \end{equation}\]

여기서 \(\mathcal{L}_t\) 항은는 step $t$에 대한 likelihood 성분을 나타낸다. 중요한 것은 데이터 포인트의 모든 \(\mathcal{L}_t\) 항을 동시에 최적화할 필요가 없다는 것이다. 대신, minibatch의 각 데이터 포인트에 대해 단일 \(\mathcal{L}_t\) 항이 최적화되며 여기서 $t$는 균일 분포에서 샘플링된다. 이 목적 함수는 원래 순서에 구애받지 않는 ARM을 학습하기 위해 제안되었다. 이러한 관점에서 시작하여 ARDM을 개발하고 순서에 구애받지 않는 ARM의 특수한 경우인 order agnostic ARDM (OA-ARDM)을 참조한다. 흥미롭게도 각 \(\mathcal{L}_t\) BERT와 같은 목적 함수로 볼 수 있으며, 여기서 정확히 $D − t + 1$개의 토큰이 마스킹되고 이후에 예측된다. 따라서 OA-ARDM은 재가중 항 $\frac{1}{D-t+1}$을 포함하는 loss 항 \(\mathcal{L}_t\)가 있는 $D$개의 BERT 컬렉션으로 학습된다. 또 다른 통찰력은 이 생성 프로세스가 모델이 흡수된 (또는 가려진) 변수를 생성하는 것을 목표로 하는 absorbing diffusion과 매우 유사하다는 것이다. 어떤 상황에서는 likelihood 항 대신 loss 항을 참조하기를 원할 수 있으므로 \(L_t = − \mathcal{L}_t\)로 정의한다.

Parametrization

모든 $\sigma$와 $t$에 대해 $k \in \sigma(\ge t)$에 대한 모델 분포 $\log p(x_k \vert x_{\sigma(< t)})$에 대한 parameterization을 원한다. 각각의 $\sigma$와 $t$에 대해 원칙적으로 완전히 새로운 신경망을 가질 수 있다. 그러나 이것은 $t$의 수가 $O(D)$만큼 커지고 $\sigma$의 수가 $O(D!)$만큼 커지므로 매우 불편할 것이다. 대신 단일 신경망이 활용되고 서로 다른 $\sigma$와 $t$에 대해 공유된다. 이는 입력에서 변수를 마스킹하고 출력에서 변수를 예측하여 구현된다.

정확히 말하면 \(x \in \mathcal{X} = \{1, \cdots, K\}^D\)는 각 차원에 대한 확률 벡터를 출력하는 신경망 $f : \mathcal{X} \rightarrow \mathbb{R}^{D \times K}$와 $K$개의 클래스가 있는 이산 변수를 나타낸다. 컨디셔닝은 마스킹을 통해 수행된다. 주어진 순열 배열 $\sigma$에 대해 boolean mask를 생성하는 element-wise 비교 $m = \sigma < t$를 계산한다. 그런 다음 $\theta = f(m \odot x)$를 예측하여 마스크를 사용한다. 여기서 $\odot$은 element-wise 곱셈이다. 각 위치 $k \in \sigma (\ge t)$에 대해 로그 확률 벡터 $\theta_k$가 사용된다. $\mathcal{C} (x_k \vert \theta_k)$가 클래스 확률이 $\theta_k$인 $x_k$에 대한 카테고리형 분포라 하면

\[\begin{equation} \log (x_k \vert x_\sigma(< t)) = \log \mathcal{C} (x_k \vert \theta_k) \end{equation}\]

를 모델링하도록 선택한다. 이러한 인덱스 $k \in \sigma (\ge t)$의 위치는 반대 마스크 $1 - m$을 사용하여 검색된다. 이 parameterization으로 ARDM을 샘플링하고 최적화하는 절차는 Algorithm 1과 2에 나와 있다.


학습 단계는 다음 그림에 시각화되어 있다.


Step당 단일 출력만 사용된 아래 그림과 달리 학습 단계에서는 마스킹된 모든 차원이 동시에 예측된다.


따라서 모델을 최적화하기에 충분한 신호가 있는지 예측할 여러 변수가 있다.

함수 $f$에 대한 입력은 데이터 modality에 따라 다를 수 있다. 이미지와 오디오의 경우 feature 정규화 후 값이 0이 되도록 마스크가 입력에 적용된다. 마스크 자체도 입력 표현으로 입력에 concat되어 모델이 값이 실제로 0인지 또는 값이 흡수 상태 0인지 여부를 식별할 수 있다. 언어의 경우 입력 표현이 증강되고 흡수된 값이 새 클래스 $K + 1$로 설정된다. 이 경우 마스크 자체를 모델에 대한 입력으로 제공할 필요가 없다. 더 일반적으로, 마스킹된 상태를 $x$와 같은 모양을 갖지만 미리 지정된 값만 포함하는 흡수 상태 벡터 $a$로 나타낼 수 있다. 그러면 네트워크에 대한 입력은 마스킹된 $m \odot x$가 아니라 \(m \odot x + (1 − m) \odot a\)이다. 또한 네트워크 $f$는 일반적으로 diffusion model에서 수행되는 것처럼 시간 성분 $t$를 입력으로 사용할 수도 있다.

요약하면 네트워크는 $\theta = f (i,m,t)$로 몇 가지 추가 입력을 받는다. 여기서 $i = m \odot x + (1 − m) \odot a$이고 $x$의 처리는 데이터 유형에 따라 다를 수 있다.

1. Parallelized ARDMs

Parametrization의 중요한 속성은 여러 변수에 대한 분포가 동시에 예측된다는 것이다. 이 parameterization을 활용하여 병렬 독립 변수 생성을 허용한다. 본질적으로, $x_{\sigma (< t)}$에 대해서만 컨디셔닝하면서 양수 $k$에 대해 $x_{\sigma (t+k)}$의 분포를 원한다. 먼저 미래 변수 예측과 likelihood 항 사이의 연결에 대해 관찰한다. $k = 1, \cdots, D-t$에 대하여 순열에 대한 기대값 때문에

\[\begin{equation} \mathbb{E}_\sigma [\log p(x_{\sigma (t+k)} \vert x_{\sigma (< t)})] = \mathbb{E}_\sigma [\log p(x_{\sigma (t)} \vert x_{\sigma (< t)})] = \mathcal{L}_t \end{equation}\]

이다. 즉, 모델이 예측하는 step $t + k$는 중요하지 않다. 이러한 step은 모두 동일한 연관 likelihood를 갖는다. 그 결과, $t$번째 변수에서 시작하여 독립적으로 $k$개의 토큰을 순서에 구애받지 않고 생성하면 단일 step에서 \(k \cdot \mathcal{L}_t\)의 로그 확률 기여도가 생기는 반면, 전통적인 접근 방식은 \(\sum_{i=1}^k \mathcal{L}_{t+i}\)의 비용으로 $k$ step을 수행한다. 이는 주어진 예산에서 어느 순간에 얼마나 많은 병렬 step을 수행해야 하는지 계산하는 동적 프로그래밍 알고리즘을 구성하기에 충분하다. 동적 프로그래밍은 일반적으로 최소화 관점에서 설명되기 때문에 비트 단위로 측정되는 loss 성분 \(L_t = -\mathcal{L}_t\)를 정의한다. Loss 측면에서 timestep $t$에서 $k$개의 변수를 생성하면 $k \cdot L_t$ 비트의 비용이 발생한다.

또한 transition cost matrix를 양의 정수 $k$에 대해 $L_{t,t+k} = k · L_t$로 정의하고 그렇지 않은 경우에는 $L_{t+k,t} = 0$으로 정의한다. 따라서 $L_{t,t+k}$는 모든 관련 $t$와 $k$에 대해 $t$번째 위치에서 시작하여 다음 $k$개의 변수를 병렬로 모델링하는 데 드는 비용을 정확하게 설명한다. 이 transition cost matrix를 사용하여 동적 프로그래밍 알고리즘을 활용하여 병렬화해야 하는 step을 찾을 수 있다.


예를 들어, 위 그림의 예에서 가상의 20 step 문제에는 5 step의 예산이 주어진다. 일반적으로 알고리즘은 $L_t$ 성분 간의 차이가 큰 영역에서 더 많은 step을 수행하고 $L_t$ 성분이 거의 동일한 영역에서는 더 적은 step을 수행한다. 다음과 같이 잘 보정된 모델의 경우처럼 ARDM을 병렬화하면 약간의 비용이 발생할 수 있다.

\[\begin{equation} \mathcal{L}_t = \mathbb{E}_\sigma [\log p(x_{\sigma (t+1)} \vert x_{\sigma (< t)})] \le \mathbb{E}_\sigma [\log p(x_{\sigma (t+1)} \vert x_{\sigma (< t+1)})] = \mathcal{L}_{t+1} \end{equation}\]

그러나 더 적은 step이 사용되기 때문에 더 빠른 생성을 위해 교환될 수 있다. 즉, loss 성분 $L_t$는 $t$에 대해 단조 감소하고 모델을 병렬화하면 비용이 발생하며 알고리즘은 이를 최소화하는 것을 목표로 한다. 이것은 실제로 관찰되는 모델이 잘 보정되었다는 가정 하에 있다.

2. Depth Upscaling ARDMs

OA-ARDM은 무작위 순서로 변수를 생성하는 방법을 배운다. 결과적으로 매우 상세한 정보 (ex. 이미지의 최하위 비트)에 대한 결정은 생성 프로세스에서 비교적 초기에 모델링된다. 대신 프로세스를 여러 stage로 구성할 수 있으며 각 stage에 대해 변수가 개선된다. 이 프로세스를 업스케일링이라고 한다. 예를 들어 전체 256개 카테고리형 변수를 한 번에 생성하는 대신 먼저 최상위 비트를 생성한 다음 중요도 순으로 후속 비트를 생성할 수 있다. 프로세스를 정의하려면 먼저 업스케일링의 반대 프로세스인 파괴적인 다운스케일링 프로세스를 상상하는 것이 도움이 된다.

데이터 변수가 데이터 값에서 일반적인 흡수 상태로 다운스케일링되는 방법을 정의하는 transition matrix $P^{(i)}$를 통해 맵을 정의할 수 있다. 단순화를 위해 단일 차원 변수를 가정한다. 흡수 상태를 one-hot벡터 $x^{(0)}$로 표시한다. Diffusion 관점에서 업스케일링은 각 변수가 최하위 비트를 0으로 만드러 감소하는 다운스케일링 파괴 프로세스를 보완한다.

$P^{(1)}, \cdots, P^{(S)}$를 다운스케일링 맵의 시퀀스로 정의하여, 임의의 카테고리형 one-hot 데이터 변수 $x^{(S)} \in {0, 1}^K$에 대해 $P^{(1)} \cdot \ldots \cdot P(S) \cdot x^{(S)} = x^{(0)}$를 만족하도록 한다. 즉, 모든 카테고리 $K$는 $S$개의 다운스케일링 맵 이후 공통 흡수 상태로 감소한다. 이제 특히

\[\begin{equation} p(x^{(S)} \vert x^{(S-1)}) \cdot \ldots \cdot p (x^{(2)} \vert x^{(1)}) p(x^{(1)}) \end{equation}\]

을 모델링하여 다운스케일링 맵의 역순을 학습하여 업스케일링 생성 프로세스를 정의한다. Transition matrix는 다음 규칙을 통해 다른 stage 변수 $x^{(i)}$ 간의 쉬운 transition을 할 수 있게 한다.

\[\begin{equation} x^{(s)} = P^{(s+1)} x^{(s+1)} = \bar{P}^{(s+1)} x^{(S)} \\ \textrm{where} \quad \bar{P}^{(s)} = P^{(s)} \cdot P^{(s+1)} \cdot \ldots \cdot P^{(S)} \end{equation}\]

행렬 $P^{(i+1)}$은 누적 행렬 곱셈으로 계산되며 데이터 포인트 $x^{(S)}$에서 대응되는 다운스케일링 변수 $x^{(i)}$로 직접 전환할 수 있다. 이는 모델이 데이터 포인트당 단일 특정 stage에 대해서만 최적화되는 학습 중에 특히 유용하다. 구현을 위해 $P^{(S+1)} = I$를 항등 행렬로 정의하는 것이 일반적으로 유용하므로 $s = S$일 때 위의 방정식도 유지된다. Upscale ARDM을 학습시키기 위해 Algorithm 2를 확장할 수 있다.

Timestep $t$를 샘플링하는 것 외에도 최적화할 stage $i \sim \mathcal{U}(1, \ldots, S)$를 샘플링한다. 이 특정 stage에 대해 ARDM은 stage 내의 순열 $\sigma$와 stage 내의 timestep $t$를 샘플링하여 $p(x^{(s)} \vert x^{(s-1)})$를 모델링한다. 모든 항 $p(x^{(s)} \vert x^{(s−1)})$는 OA-ARDM으로 모델링된 stage를 나타낸다. 이것은 ARDM의 흥미로운 속성을 강조한다. 모델에서 샘플링하는 데 최대 $D\cdot S$ step이 걸릴 수 있지만 학습 복잡성은 여러 stage를 모델링해도 변경되지 않는다. 결과적으로 학습 중에 계산 복잡성을 증가시키지 않고 임의의 수의 stage를 추가하여 실험할 수 있다.

Bit Upscaling

깊이 업스케일링은 비트 업스케일링을 예로 들어 가장 쉽게 설명할 수 있다. 픽셀 값이 \(\{0, \ldots, 255\}\)인 표준 이미지를 생성하는 task를 생각해보자. 그러면 $D$ 차원의 이미지는 one-hot 표기법으로 \(x^{(8)} \in \{0, 1\}^{D \times 256}\)으로 나타낼 수 있다. $i$개의 최하위 비트를 제거하는 함수

\[\begin{equation} \textrm{lsb}_s (k) = \lfloor k / 2^s \rfloor 2^s \end{equation}\]

에 의해 정의된 다운스케일링 프로세스를 상상해 보자. 이 함수를 통해 다음과 같은 전환 행렬을 정의할 수 있다.

\[\begin{equation} P_{l, k}^{(8 + 1 - s)} = \begin{cases} 1 & \quad \textrm{if} \; l = \textrm{lsb}_s (k), \; k \in \textrm{Im} (\textrm{lsb}_{s-1}) \\ 0 & \quad \textrm{otherwise} \end{cases} \end{equation}\]

여기서 \(\{P^{(s)}\}\)는 0부터 인덱싱된다. 이 경우 모든 \(k \in \{0, \ldots, 255\}\)에 대해 \(\textrm{lsb}_8 (k) = 0\)이기 때문에 모든 값을 흡수 상태 0으로 매핑하는 8 stage이다. 카테고리가 적은 문제에 대한 이러한 행렬의 시각화는 아래 그림과 같다.


깊이 업스케일링은 비트에 국한되지 않으며 실제로 보다 일반적인 공식은 branching factor $b$에 대한 다운스케일링 맵 $l = \lfloor k/b^s \rfloor \cdot b^s$에 의해 제공된다. $b$가 2로 설정되면 비트 업스케일링이다. $b$가 더 높은 값으로 설정되면 변수는 더 적은 stage (정확히는 $\lceil \log_b (K) \rceil$)에서 생성될 수 있다. 이를 통해 모델이 수행하는 stage 수와 각 모델링 step이 억제하는 복잡도 사이에 고유한 trade-off가 가능하다.

Parametrization of the Upscaling Distributions

이제 데이터 포인트 $x^{(S)}$가 $x^{(S−1)}, \ldots, x^{(1)}$과 흡수 상태 $x^{(0)}$로 축소되는 방법이 정의되었지만 분포 $p(x^{(s)} \vert x^{(s−1)})$를 parameterize하는 것이 즉시 명확하지 않다. 두 가지 방법을 사용하여 분포를 parameterize할 수 있다.

첫 번째는 direct parametrization이다. 위의 비트 업스케일링 모델의 예에서 $(s - 1)$번째 유효 비트가 주어지면 $s$번째 유효 비트를 모델링한다. Direct parametrization은 현재 stage와 관련된 분포 파라미터 출력만 필요하기 때문에 일반적으로 더 계산적으로 효율적이다. 이는 클래스 수가 많은 경우에 특히 유용하다 (ex. $2^{16}$개의 클래스가 있는 오디오). 그러나 어떤 클래스가 관련이 있고 모델링해야 하는지 정확히 파악하는 것은 다소 지루할 수 있다.

또는 D3PM의 parameterization과 유사한 data parametrization를 사용할 수 있다. D3PM과 중요한 차이점은 다운스케일링 행렬 $P^{(s)}$는 결정론적 맵을 나타내는 반면 D3PM의 것은 확률적 프로세스를 나타낸다는 것이다. 이 parameterization을 위해 네트워크 $f$는 다음 식을 통해 stage $s$에서 관련 확률로 변환된 데이터 $x^{(S)}$의 모양과 일치하는 확률 벡터 $\theta$를 출력한다.

\[\begin{equation} \theta^{(s)} = \frac{P^{(s) \top} x^{(s-1)} \odot \bar{P}^{(s+1)} \theta}{x^{(s-1) \top} \bar{P}^{(s)} \theta} \\ \textrm{where} \quad p(x^{(s)} \vert x^{(s-1)}) = \mathcal{C} (x^{(s)} \vert \theta^{(s)}) \end{equation}\]

이 parameterization의 장점은 transition matrix \(\{P^{(s)}\}\)만 정의하면 된다는 것이다. 결과적으로 적절한 확률을 자동으로 계산할 수 있으므로 새로운 다운스케일링 프로세스를 실험하는 데 이상적이다. 단점은 클래스 수가 많은 문제에 대한 전체 확률 벡터를 모델링하는 데 비용이 많이 들고 메모리에 적합하지 않을 수도 있다는 것이다. 저자들은 경험적으로 이미지 데이터에 대한 실험에서 두 parameterization 사이에 의미 있는 성능 차이가 없음을 발견했다.

Results

Order Agnostic Modelling

다음은 text8 (왼쪽)과 CIFAR-10 (오른쪽)에서 순서에 구애받지 않는 모델들의 성능을 비교한 표이다.

 

Lossless Compression

다음은 CIFAR-10에서의 무손실 압축 성능을 비교한 표이다.

Effects of Depth-Upscaling

다음은 오디오 (SC09)에 대한 깊이 업스케일링 성능을 비교한 표이다. 하나의 순서만 학습하는 WaveNet baseline은 7.77 bpd를 달성한다.


다음은 이미지 (CIFAR-10)에 대한 깊이 업스케일링 성능을 비교한 표이다.

Limitations

  1. ARDM이 텍스트에 대한 다른 모든 순서에 구애받지 않는 접근 방식을 능가하더라도 하나의 순서만 학습한 ARM의 성능에는 여전히 차이가 있다.
  2. 현재 설명에서 ARDM은 이산 변수를 모델링한다. 원칙적으로 연속 분포에 대한 흡수 프로세스를 정의할 수도 있다.
  3. 본 논문에서는 무손실 압축에서 코딩 길이와 직접적으로 일치하기 때문에 log-likelihood를 최적화하는 데 중점을 두었다. 그러나 샘플 품질과 같은 다른 목표를 위해 최적화할 때 다른 아키텍처 선택으로 더 나은 결과를 얻을 수 있다.