arXiv 2025. [Paper] [Blog] [Github]
Shen Nie, Fengqi Zhu, Zebin You, Xiaolu Zhang, Jingyang Ou, Jun Hu, Jun Zhou, Yankai Lin, Ji-Rong Wen, Chongxuan Li
Renmin University of China | Ant Group
14 Feb 2025

Introduction

LLM은 전적으로 생성 모델링 프레임워크 내에 있다. LLM은 maximum likelihood estimation을 통해 모델 분포 \(p_\theta (\cdot)\)를 최적화하거나, 이와 동등하게 두 분포 사이의 KL divergence 최소화를 통해 알려지지 않은 언어 분포 \(p_\textrm{data} (\cdot)\)를 모델링하는 것을 목표로 한다.

\[\begin{equation} \max_\theta \mathbb{E}_{p_\textrm{data} (x)} \log p_\theta (x) \; \Leftrightarrow \; \min_\theta \textrm{KL} (p_\textrm{data} (x) \| p_\theta (x)) \end{equation}\]

주요 접근 방식은 autoregressive modeling (ARM)에 의존한다. ARM은 일반적으로 next-token prediction 패러다임이라고 하며 모델 분포를 다음과 같이 정의한다.

\[\begin{equation} p_\theta (x) = p_\theta (x^1) \prod_{i=2}^L p_\theta (x^i \, \vert \, x^1, \cdot, x^{i-1}) \end{equation}\]

이 패러다임은 놀라울 정도로 효과적임이 입증되었고 현재 LLM의 기초가 되었다.

저자들은 autoregressive 패러다임이 현재 LLM이 보여주는 지능을 달성하는 유일한 방법이 아니라고 주장한다. 저자들의 핵심 통찰력은 LLM의 필수 속성을 근본적으로 뒷받침하는 것이 autoregressive 공식 자체가 아니라 생성 모델링 원리라는 것이다. 그러나 LLM의 특정 내재적 한계는 autoregressive한 특성에서 직접 추적할 수 있다.

특히, 저자들은 LLM의 scalability가 ARM의 고유한 결과라기보다는 주로 Transformer, 모델 및 데이터 크기, 생성 모델링 원리에 의한 결과라고 주장하였다. 비전 데이터에 대한 Diffusion Transformer의 성공은 이 주장을 뒷받침한다.

더욱이, instruction-following과 in-context learning의 능력은 구조적으로 일관된 언어적 task에 대한 모든 적절한 조건부 생성 모델의 본질적인 속성인 듯하며, ARM에 따른 이점이 아니다. 또한 ARM은 무손실 데이터 압축기로 해석될 수 있지만, 충분히 표현력이 뛰어난 확률적 모델은 유사한 능력을 달성할 수 있다.

그럼에도 불구하고 LLM의 autoregressive한 특성은 주목할 만한 한계점을 가지고 있다. 예를 들어, 순차적인 토큰별 생성은 높은 계산 비용을 초래하고, 왼쪽에서 오른쪽으로의 모델링은 reversal reasoning task에서 효과를 제한한다. 이러한 고유한 한계는 LLM이 더 길고 복잡한 task를 처리하는 데 제약을 준다.

이러한 통찰력에 의해 동기를 부여받아, 본 논문은 LLaDA (Large Language Diffusion with mAsking)를 도입하여 LLM이 보여주는 역량이 ARM을 넘어서는 생성 모델링 원리에서 나올 수 있는지 조사하였다. 기존 ARM과 달리, LLaDA는 discrete한 랜덤 마스킹 프로세스를 통합하고 mask predictor를 학습하여 reverse process를 근사화하는 masked diffusion model (MDM)을 활용한다. 이를 통해 LLaDA는 양방향 의존성을 가진 모델 분포를 구성하고 log-likelihood 하한을 최적화하여 기존 LLM에서 탐구되지 않았고 원칙적인 대안을 제공한다.

저자들은 데이터 준비, 사전 학습, supervised fine-tuning (SFT), 모델 평가로 구성된 표준 LLM 파이프라인을 채택하여 LLaDA를 8B 크기의 전례 없는 언어 diffusion model로 scaling하였다. 특히, LLaDA 8B는 0.13백만 H800 GPU 시간을 사용하여 2.3조 개의 토큰에서 처음부터 사전 학습되었고, 그 다음 4.5백만 쌍의 토큰에서 SFT가 수행되었다. LLaDA는 언어 이해, 수학, 코딩, 중국어를 포함한 다양한 task에서 scalability, in-context learning, instruction-following, reversal reasoning 능력을 보여준다.

Method

1. Probabilistic Formulation

ARM과 달리, LLaDA는 forward processreverse process를 통해 모델 분포 \(p_\theta (x_0)\)을 정의한다. Forward process는 $t = 1$에서 시퀀스가 ​​완전히 마스킹될 때까지 $x_0$의 토큰을 점차적으로 독립적으로 마스킹한다. $t \in (0, 1)$의 경우, 시퀀스 $x_t$는 부분적으로 마스킹되며 각각은 확률 $t$로 마스킹되거나 확률 $1-t$로 마스킹되지 않은 상태로 유지된다. Reverse process는 $t$가 1에서 0으로 이동함에 따라 마스킹된 토큰을 반복적으로 예측하여 데이터 분포를 복구한다.

LLaDA의 핵심은 mask predictor $p_\theta (\cdot \vert x_t)$로, $x_t$를 입력으로 받고 마스킹된 모든 토큰 $M$을 동시에 예측한다. Mask predictor는 마스킹된 토큰에서만 계산된 cross-entropy loss을 사용하여 학습된다.

\[\begin{equation} \mathcal{L} (\theta) = - \mathbb{E}_{t \sim U(0,1), x_0, x_t} \left[ \frac{1}{t} \sum_{i=1}^L \unicode{x1D7D9} [x_t^i = M] \log p_\theta (x_0^i \vert x_t) \right] \end{equation}\]

$x_0$는 학습 데이터에서 샘플링되고, $x_t$는 forward process에서 샘플링된다. Indicator function $\unicode{x1D7D9}[\cdot]$는 loss가 마스킹된 토큰에 대해서만 계산되도록 보장한다.

일단 학습되면, mask predictor로 parameterize된 reverse process를 시뮬레이션하고 모델 분포 $p_\theta (x_0)$를 $t = 0$에서 유도된 marginal distribution으로 정의할 수 있다. 특히, 위의 loss는 모델 분포의 negative log-likelihood (NLL)에 대한 상한으로 입증되었기 때문에, 이 loss를 생성 모델링을 위한 목적 함수로 사용할 수 있다.

\[\begin{equation} - \mathbb{E}_{p_\textrm{data} (x_0)} [\log p_\theta (x_0)] \le \mathcal{L} (\theta) \end{equation}\]

Masked language model들은 고정된 마스킹 비율을 사용하는 반면, LLaDA는 0과 1 사이에서 무작위로 변하는 마스킹 비율을 사용한다. 이 미묘한 차이는 특히 scale에 따라 상당한 의미를 갖는다. 위 식에서 볼 수 있듯이, LLaDA는 LLM과 유사하게 자연스럽게 in-context learning을 수행할 수 있는 잠재력을 가진 원칙적인 생성 모델이며, Fisher consistency를 보장하여 대규모 데이터와 모델에 대한 강력한 scalability를 시사한다.

2. Pre-training


LLaDA는 Transformer를 mask predictor로 사용하는데, 그 아키텍처는 기존 LLM과 유사하다. 그러나 LLaDA는 causal mask를 사용하지 않으며, 이는 예측을 위해 전체 입력을 볼 수 있게 하기 위함이다.

저자들은 서로 다른 크기의 두 가지 LLaDA 모델, 1B 모델과 8B 모델을 학습시켰다. LLaDA 8B는 LLaMA3 8B의 대부분의 hyperparameter를 따른다. LLaDA는 KV caching과 호환되지 않아 key head와 value head의 수가 달라지기 때문에 grouped query attention 대신 일반적인 multi-head attention을 사용한다. 결과적으로 attention layer에는 더 많은 파라미터가 있으며 비슷한 모델 크기를 유지하기 위해 FFN 차원을 줄였다. 또한, vocabulary 크기는 데이터에 맞게 적용된 tokenizer로 인해 약간 다르다.

LLaDA 모델은 2.3조 개의 토큰으로 구성된 데이터셋에서 사전 학습되었으며, 특별한 테크닉을 사용하지 않고도 기존의 LLM과 동일한 데이터 프로토콜을 준수하였다. 데이터는 온라인 코퍼스에서 파생되었으며, 일반 텍스트 외에도 고품질 코드, 수학, 다국어 데이터를 포함한다.

학습 시퀀스 $x_0$에 대하여, $t \in [0, 1]$을 무작위로 샘플링하고, 각 토큰을 동일한 확률 $t$로 독립적으로 마스킹하여 $x_t$를 구한다. 또한, 가변 길이의 데이터를 처리하는 LLaDA의 능력을 향상시키기 위해 사전 학습 데이터의 1%를 범위 [1, 4096]에서 균일하게 샘플링된 무작위 길이로 설정한다.

  • 13만 H800 GPU hour
  • 시퀀스 길이: 4,096으로 고정
  • learning rate: Warmup-Stable-Decay
    • 처음 2,000 iteration에서 0에서 $4 \times 10^{-4}$로 선형적으로 증가
    • 1.2조 개의 토큰을 처리할 동안 $4 \times 10^{-4}$로 유지
    • $1 \times 10^{-4}$로 감소시킨 후 8,000억개의 토큰 동안 유지
    • 마지막 3,000억 개의 토큰 동안 $1 \times 10^{-4}$에서 $1 \times 10^{-5}$로 선형적으로 감소
  • weight decay: 0.1
  • global batch size: 1,024 (GPU당 4)
  • optimizer: AdamW

3. Supervised Fine-Tuning


저자들은 프롬프트-응답 쌍 데이터 $(p_0, r_0)$를 사용하여 supervised fine-tuning (SFT)을 통해 LLaDA의 instruction following 능력을 향상시켰다. 이것은 LLM에 대한 가장 간단하고 가장 기본적인 사후 학습 방법이며, 사전 학습에서 \(p_\theta (x_0)\) 대신 조건부 분포 \(p_\theta (r_0 \vert p_0)\)를 모델링해야 한다.

구현은 사전 학습과 유사하다. 프롬프트는 변경하지 않고 응답의 토큰을 $x_0$에서와 같이 독립적으로 마스킹한다. 그런 다음 프롬프트와 마스킹된 응답 $r_t$를 모두 사전 학습된 mask predictor에 공급하여 SFT에 대한 loss를 계산한다.

\[\begin{equation} -\mathbb{E}_{t \sim U(0,1), p_0, r_0, r_t} \left[ \frac{1}{t} \sum_{i=1}^{L^\prime} \unicode{x1D7D9} [r_t^i = M] \log p_\theta (r_0^i \vert p_0, r_t) \right] \end{equation}\]

이 접근 방식은 사전 학습과 완벽하게 호환된다. $p_0$와 $r_0$의 concatenation은 사전 학습 데이터 $x_0$로 사용되며, $p_0$와 $r_t$의 concatenation은 마스킹된 버전 $x_t$로 사용된다. 프로세스는 사전 학습과 동일하며 유일한 차이점은 모든 마스킹된 토큰이 $r_0$ 부분에 나타난다는 것이다.

LLaDA 8B 모델은 450만 쌍으로 구성된 데이터셋에서 SFT를 거친다. 사전 학습과 마찬가지로 데이터 준비와 학습은 기존 LLM에서 사용되는 SFT 프로토콜을 따르며, LLaDA의 성능을 최적화하기 위한 추가 테크닉을 도입하지 않았다. 데이터셋은 코드, 수학, instruction following, 구조화된 데이터 이해 등 여러 도메인에 걸쳐 있다. 저자들은 모든 데이터에서 길이가 동일하도록 각 mini-batch의 짧은 쌍 끝에 EOS 토큰을 추가하였다. 학습 중에는 EOS 토큰을 일반 토큰으로 처리하고 샘플링 중에는 제거하여 LLaDA가 응답 길이를 자동으로 제어할 수 있도록 하였다.

  • epoch: 3
  • learning rate: Warmup-Stable-Decay
    • 처음 50 iteration에서 0에서 $2.5 \times 10^{-5}$로 선형적으로 증가
    • $2.5 \times 10^{-5}$로 유지
    • 마지막 10%의 iteration에서는 $2.5 \times 10^{-5}$에서 $2.5 \times 10^{-6}$으로 선형적으로 감소
  • weight decay: 0.1
  • global batch size: 256 (GPU당 2)

4. Inference


생성 모델인 LLaDA는 새로운 텍스트를 샘플링하고 후보 텍스트의 likelihood를 평가할 수 있다.

프롬프트 $p_0$가 주어지면 reverse process를 discretize하여 모델 분포 \(p_\theta (r_0 \vert p_0)\)에서 샘플링하고 완전히 마스킹된 응답에서 시작한다. 샘플링 step의 수는 hyperparameter로, 효율성과 샘플 품질 간의 trade-off를 제공한다. 기본적으로 균일하게 분포된 timestep을 사용한다. 또한 생성 길이도 hyperparameter로 처리되어 샘플링 프로세스 시작 시 완전히 마스킹된 문장의 길이를 지정한다. 사전 학습과 SFT는 모두 가변 길이의 데이터셋을 사용하여 수행되므로 최종 결과는 이 길이 hyperparameter에 영향을 받지 않는다.

시간 $t \in (0, 1]$에서 $s \in [0, t)$까지의 중간 step에서, $p_0$와 $r_t$를 모두 mask predictor에 입력하고 모든 마스킹된 토큰을 동시에 예측한다. 그런 다음 $r_s$를 얻기 위해 예측된 토큰의 $\frac{s}{t}$를 다시 마스킹하여 정확한 샘플링을 위해 reverse process의 전환이 forward process와 일치하도록 한다.

원칙적으로 remasking 전략은 완전히 무작위적이어야 한다. 하지만 저자들은 LLM에서 샘플링의 annealing trick에서 영감을 얻어, 두 가지 deterministic하지만 효과적인 remasking 전략을 탐구하였다. 구체적으로 MaskGIT과 유사하게 예측에 따라 가장 낮은 신뢰도로 예측된 ​​토큰의 $\frac{s}{t}$를 다시 마스킹하는데, 이를 low-confidence remasking이라고 한다. 또한 SFT 이후의 LLaDA의 경우, 시퀀스를 여러 블록으로 나누어 왼쪽에서 오른쪽으로 생성할 수 있으며, 이를 semi-autoregressive remasking이라고 한다. 각 블록 내에서는 reverse process를 적용하여 샘플링을 수행한다.

조건부 likelihood 평가의 경우, SFT의 loss를 활용할 수도 있지만, 다음과 같은 동등한 형태가 더 낮은 분산을 보이고 평가에 더 안정적이다.

\[\begin{equation} -\mathbb{E}_{l \sim U(\{1, \ldots, L\}), r_0, r_l} \left[ \frac{L}{l} \sum_{i=1}^{L} \unicode{x1D7D9} [r_l^i = M] \log p_\theta (r_0^i \vert p_0, r_l) \right] \end{equation}\]

$r_l$은 $r_0$에서 $l$개의 토큰을 균일하게 샘플링하여 얻는다. 또한, unsupervised classifier-free guidance를 사용한다.

Experiments

1. Scalability of LLaDA on Language Tasks

다음은 LLaDA의 scalability를 ARM과 비교한 그래프이다.


2. Benchmark Results

다음은 사전 학습 후의 LLaDA를 다른 사전 학습된 LLM들과 비교한 결과이다.


다음은 SFT 후의 LLaDA를 다른 사후 학습된 LLM들과 비교한 결과이다.

3. Reversal Reasoning and Analyses

다음은 시 완성 task에 대한 결과를 비교한 표이다.

4. Case Studies

다음은 샘플링 프로세스와 멀티 라운드 대화의 예시이다.