arXiv 2025. [Paper]
Ali Behrouz, Peilin Zhong, Vahab Mirrokni
Google Research
31 Dec 2024

Introduction

Transformer는 순수한 attention 기반 아키텍처로, 주로 in-context learning과 대규모 학습 능력 덕분에 시퀀스 모델링에서 SOTA 모델로 확고히 자리 잡았다. Transformer의 주요 구성 요소인 attention 모듈은 연상 기억 블록으로 기능하여 key-value 관계를 저장하고 query(검색 신호)와 key(컨텍스트) 간의 쌍별 유사성을 계산하여 검색하는 방법을 학습한다. 따라서 설계상 Transformer의 출력은 현재 context window에서 토큰 간의 직접적인 의존성에만 의존한다. 그러나 이러한 의존성 모델링의 컨텍스트 길이의 제곱에 비례하는 복잡도를 수반한다. 복잡한 실제 task에서 context window가 매우 커질 수 있으므로 이러한 다운스트림 task에서 Transformer를 적용하는 것은 어렵다.

Transformer의 확장성 문제를 극복하기 위해 최근 연구들에서는 다양한 linear Transformer를 설계하는 것을 목표로 하였지만, 성능이 Transformer보다 좋지 못하며 매우 긴 컨텍스트는 작은 벡터 값 또는 행렬 값 상태로 적절하게 압축될 수 없다. 또한 대부분의 기존 아키텍처는 일반화, 길이 외삽, 추론을 처리할 때 어려움에 직면한다.

신경심리학의 기억과 학습에 대한 일반적인 정의에서 영감을 얻은 대부분의 기존 아키텍처는 기억을 입력으로 인한 파라미터 업데이트로 간주하고, 학습을 주어진 목적 함수에 대하여 효과적이고 유용한 기억을 습득하는 프로세스로 정의하였다. 이러한 관점에서 Transformer의 key와 value 행렬은 모델의 메모리 역할을 하며, 모델은 key와 value를 압축 없이 메모리에 추가하여 메모리를 업데이트하고, query 벡터와 key 벡터의 유사성을 찾아 query 벡터에 대한 기억을 검색한 다음, 이를 사용하여 출력에 대한 value 벡터에 가중치를 부여한다.

이러한 관점은 기존 패러다임과 그 중요한 차이점을 더 잘 이해하고, 더 효과적인 아키텍처를 설계하는 데 도움이 될 수 있다. 예를 들어, linear Transformer는 과거 데이터를 고정된 크기의 행렬 값 메모리로 압축하는 반면, Transformer는 모든 과거 데이터를 압축 없이 보관한다. 따라서 이러한 관점은 다음과 같은 질문을 하게 한다.

Q1. 메모리에 적합한 구조는 무엇인가?
Q2. 적절한 메모리 업데이트 메커니즘은 무엇인가?
Q3. 적합한 메모리 검색 프로세스는 무엇인가?

인간의 기억은 하나의 프로세스도 아니고 하나의 함수도 아니다. 기억은 단기 기억, 작업 기억, 장기 기억과 같은 시스템의 연합으로, 각각은 다른 구조로 다른 기능을 수행하며, 각각 독립적으로 작동할 수 있다. 이 사실은 다음과 같은 질문을 던지게 한다.

Q4. 서로 연결된 다양한 메모리 모듈을 통합하는 효율적인 아키텍처를 설계하는 방법은 무엇인가?

마지막으로, 기억을 저장하는 것은 과거 데이터의 추상화를 인코딩하고 저장해야 하는 프로세스이다. 파라미터가 선형 방식으로 데이터를 인코딩하는 단일 벡터나 행렬이 과거 데이터를 저장하기에 충분하다고 가정하는 것은 지나치게 단순화한 것일 수 있다.

Q5. 과거 데이터를 효과적으로 저장/기억하려면 신경망이 필요한가?

본 논문에서는 test-time에 기억하는 법을 효율적이고 효과적으로 배울 수 있는 long-term memory 모듈을 설계하여 위의 5가지 질문에 답하는 것을 목표로 한다. 또한 그 설계를 바탕으로 아키텍처에 어떻게 통합할 수 있는지 논의하였다.

저자들은 test-time에 데이터를 파라미터로 기억/저장하는 방법을 학습하는 long-term memory를 제시하였다. 인간의 장기 기억 시스템에서 영감을 받아, 저자들은 기대에 어긋나는 사건(놀라움)이 더 기억에 남도록 이 메모리 모듈을 설계하였다. 이를 위해, 입력에 대한 신경망의 기울기로 입력의 놀라움을 측정하였다.

제한된 메모리를 더 잘 처리하기 위해, 메모리 크기의 비율과 데이터 놀라움의 양을 고려하는 감쇠(decay) 메커니즘을 제시하여 더 나은 메모리 관리를 제공한다. 흥미롭게도, 이 메커니즘은 mini-batch gradient descent, 모멘텀, weight decay를 사용하여 메타 신경망을 최적화하는 것과 동일하다. 더 많은 matmul 연산을 사용하기 위해 mini-batch gradient descent를 텐서화(tensorize)하는 것을 기반으로, 저자들은 long-term memory를 학습시키기 위한 빠르고 병렬화가 가능한 알고리즘을 제시하였다.

Long-term memory를 설계한 후 남은 중요한 질문은 메모리를 딥러닝 아키텍처에 효과적이고 효율적으로 통합하는 방법이다. 저자들은 세 개의 hyper-head로 구성된 Titans를 제시하였다.

  1. Core: Short-term memory로 구성되며 데이터 처리의 주요 흐름을 담당한다. 제한된 window 크기의 attention을 사용한다.
  2. Long-term Memory: 오랜 과거 데이터를 저장/기억하는 것을 담당한다.
  3. Persistent Memory: Task에 대한 지식을 인코딩하는 학습 가능하지만 날짜와 독립적인 파라미터 집합이다.

마지막으로 저자들은 메모리를 컨텍스트, 레이어, gated branch로 통합하는 Titans의 세 가지 변형을 제시하였다.

Titans 아키텍처는 포괄적인 벤치마크에서 모든 RNN 아키텍처보다 성능이 뛰어나다. 또한 동일한 context window를 사용하는 Transformer보다 성능이 뛰어나고 전체 컨텍스트를 사용하는 Transformer와 비슷한 성능을 보여준다.

Learning to Memorize at Test Time

본 논문에서는 장기 기억의 부족을 극복하고 모델이 정보를 학습하고, 잊고, 검색할 수 있도록 test-time에 기억하는 법을 배우는 메타 모델인 long-term memory 모듈을 제시하였다.

1. Long-term Memory

Long-term memory 모듈을 설계하려면 과거에 대한 추상화를 파라미터로 인코딩할 수 있는 모델이 필요하다. 간단한 아이디어는 LLM을 학습시키고 학습 데이터를 암기하도록 기대하는 것이다. 그러나 암기는 모델 일반화를 제한하고 개인 정보 문제를 야기하여 test-time 성능이 저하되기 때문에 바람직하지 않다. 게다가 학습 데이터의 암기는 데이터가 분포에서 벗어났을 수 있는 test-time에 도움이 되지 않을 수 있다.

따라서 test-time에 데이터를 암기/잊는 방법을 학습하는 온라인 메타 모델이 필요하다. 모델은 암기가 가능한 함수를 학습하지만 학습 데이터에 overfitting되지 않아 test-time에 더 나은 일반화가 이루어진다.

학습 프로세스와 Surprise metric

장기 기억을 학습시키는 핵심 아이디어는 장기 기억을 학습시키는 것을 온라인 학습 문제로 취급하는 것이다. 여기서 과거 정보 $x_1, \ldots, x_{t-1}$을 long-term memory 모듈 \(\mathcal{M}_t\)의 파라미터로 압축하는 것을 목표로 한다. 기대에 어긋나는 사건, 즉 놀라운 사건은 인간에게 더 기억에 남는다. 모델에 대한 놀라움의 간단한 정의는 입력에 대한 기울기일 수 있다. 기울기가 클수록 입력 데이터와 과거 데이터 간의 차이가 커진다. 따라서 이 놀라움 점수를 사용하여 메모리를 다음과 같이 업데이트할 수 있다.

\[\begin{equation} \mathcal{M}_t = \mathcal{M}_{t-1} - \theta_t \nabla \ell (\mathcal{M}_{t-1}; x_t) \end{equation}\]

그러나 이러한 surprise metric은 큰 놀라운 순간 이후에 나오는 중요한 정보를 놓치는 결과를 초래할 수 있다. 즉, 여러 번의 놀라운 단계 후에 기울기가 매우 작아져 평평한 local minima에 저장되고 시퀀스의 일부에 대한 정보가 누락될 수 있다. 인간의 기억 관점에서 볼 때, 사건을 기억될 수 있지만 장기간에 걸쳐 지속적으로 놀라게 하지는 않을 수 있다. 그 이유는 초기 순간이 장기간에 걸쳐 충분히 놀랍기 때문에 전체 기간 동안 기억하게 되기 때문이다. 위의 surprise metric을 개선하기 위해 surprise metric을 매우 최근 과거의 놀라움 양을 측정하는 past surprise와, 들어오는 데이터의 놀라움을 측정하는 momentary surprise로 나눈다.

\[\begin{aligned} \mathcal{M}_t &= M_{t-1} + S_t \\ S_t &= \eta_t \underbrace{S_{t-1}}_{\textrm{Past Surprise}} - \theta_t \underbrace{\nabla \ell (\mathcal{M}_{t-1}; x_t)}_{\textrm{Momentary Surprise}} \end{aligned}\]

흥미롭게도, 이 공식은 모멘텀이 $S_t$인 gradient descent와 비슷하다. 따라서 여기서 모멘텀은 시간(시퀀스 길이)에 따른 놀라움의 기억으로 작용한다. 이 공식에서 \(\eta_t\)는 데이터에 따라 달라지는 surprise decay로 ($x_t$의 함수), 시간에 따른 놀라움의 감소를 제어하고, \(\theta_t\)는 momentary surprise를 데이터에 의존적인 방식으로 최종 surprise metric에 얼마나 통합해야 하는지 제어한다.

이 데이터 의존성은 이 디자인에서 특히 중요하다. 이전 토큰의 놀라움이 다음 토큰의 놀라움에 영향을 미치는 데 필요할 수 있지만, 모든 토큰이 관련이 있고 동일한 맥락에 있는 경우 대부분 유효하다. 따라서 \(\eta_t \rightarrow 0\)으로 설정하여 마지막 놀라움을 무시하거나, \(\eta_t \rightarrow 1\)로 설정하여 마지막 놀라움을 완전히 통합해야 하는지 여부를 제어할 수 있다.

목적 함수

위의 surprise metric은 loss function $\ell$에 기반을 두고 있으며, 이는 test-time에 메모리가 그 역할을 하는 것을 배우는 목적 함수이다. 즉, 메모리 모듈은 loss function $\ell$에 기반한 함수를 학습하는 메타 모델이다. 본 논문에서는 연상 기억(associative memory)에 초점을 맞추고 있으며, 과거 데이터를 key와 value의 쌍으로 저장하는 것을 목표로 한다. Transformer와 유사하게 $x_t$가 주어지면 두 개의 linear layer를 사용하여 $x_t$를 key와 value로 projection한다.

\[\begin{equation} \textbf{k}_t = x_t W_K, \; \textbf{v}_t = x_t W_V, \quad W_K, W_V \in \mathbb{R}^{d_\textrm{in} \times d_\textrm{in}} \end{equation}\]

다음으로, 메모리 모듈이 key와 value 사이의 연관성을 학습하도록 하기 위해 loss를 다음과 같이 정의한다.

\[\begin{equation} \ell (\mathcal{M}_{t-1}; x_t) = \| \mathcal{M}_{t-1} (\textbf{k}_t) - \textbf{v}_t \|_2^2 \end{equation}\]

메타 모델(메모리)의 내부 루프에서 위의 loss function을 최적화함으로써, 모델은 test-time에 key와 value 사이의 매핑을 기억하는 방법을 학습한다. Meta-learning 모델과 유사하게 메모리의 학습은 내부 루프에서 이루어지므로 파라미터 $W_K$와 $W_v$는 위의 loss function에서 hyperparameter이다. 따라서 내부 루프에서 $\mathcal{M}$의 가중치를 최적화하고, 외부 루프에서 전체 아키텍처의 다른 파라미터를 최적화한다.

망각 메커니즘

매우 큰 시퀀스(ex. 수백만 개의 토큰)를 다룰 때, 어떤 과거 정보를 잊어야 하는지 관리하는 것이 중요하다. 이를 위해 적응적 망각 메커니즘을 사용하여 메모리가 더 이상 필요하지 않은 정보를 잊을 수 있도록 하여 메모리의 제한된 용량을 더 잘 관리한다.

다음 토큰 $x_t$가 주어지면 업데이트 규칙을 다음과 같이 수정한다.

\[\begin{aligned} \mathcal{M}_t &= (1 - \alpha_t) \mathcal{M}_{t-1} + S_t \\ S_t &= \eta_t S_{t-1} - \theta_t \nabla \ell (\mathcal{M}_{t-1}; x_t) \end{aligned}\]

여기서 \(\alpha_t \in [0, 1]\)은 메모리를 유연하게 제어하는 gating 메커니즘으로, 얼마나 많은 정보를 잊어야 하는지 결정한다. 예를 들어, \(\alpha_t \rightarrow 0\)으로 두면 과거 데이터에 영향을 주지 않고 메모리를 업데이트할 수 있고, \(\alpha_t \rightarrow 1\)로 두면 전체 메모리를 지울 수 있다. 이 weight decay 메커니즘은 RNN의 gating 메커니즘과 밀접한 관련이 있다.

메모리 아키텍처

본 논문에서는 long-term memory장의 아키텍처로서 \(L_\mathcal{M} \ge 1\)개의 레이어를 가진 간단한 MLP에 초점을 맞추었다.

벡터 값 또는 행렬 값 메모리를 사용할 때 메모리 모듈은 과거 데이터를 압축하여 한 줄에 맞춘다. 즉, meta-learning 또는 온라인 학습 관점에서 행렬 값 메모리 $\mathcal{M} = W \in \mathbb{R}^{d_\textrm{in} \times d_\textrm{in}}$을 사용하는 것은

\[\begin{equation} \ell (W_{t-1}; x_t) = \| W_{t-1} \textbf{k}_t - \textbf{v}_t \|_2^2 \end{equation}\]

를 최적화하는 것과 동일하며, 이는 온라인 linear regresssion 목적 함수이므로 최적해는 과거 데이터의 의존성이 선형이라고 가정한다. 반면에 \(L_\mathcal{M} \ge 2\)인 딥 메모리 모듈은 선형 모델보다 표현력이 뛰어나다 실제로 더 효과적이다.

메모리 검색

단순히 가중치 업데이트 없이 forward pass만을 사용하여 query에 해당하는 메모리를 검색한다. 입력 $x_t$가 주어지면 linear layer $W_Q$를 사용하여 입력을 projection시킨 후, 메모리에서 해당 정보 $y_t$를 검색한다.

\[\begin{equation} y_t = \mathcal{M}^\ast (\textbf{q}_t), \quad \textrm{where} \; \textbf{q}_t = x_t W_Q \end{equation}\]

2. How to Parallelize the Long-term Memory Training

위에서 논의한 대로, long-term memory 모듈의 설계는 모멘텀과 weight decay를 사용한 gradient descent로 loss function \(\ell (\mathcal{M}_{t-1}; x_t)\)를 최적화하여 메타 모델을 학습시키는 것과 동일하다. 따라서 시퀀스 길이가 $N$일 때, 이론적으로 long-term memory 모듈의 학습에는 $O(N)$ FLOP이 필요하다. 그러나 실제로는 학습 프로세스를 병렬화하고 하드웨어 가속기(TPU, GPU)를 최대한 활용하려면 프로세스를 텐서화(tensorize)하고 더 많은 matmul을 사용해야 한다.

Mini-batch gradient descent, 데이터에 따라 다른 learning rate, weight decay를 사용하여 내부 루프에서 가중치를 계산하는 것을 matmuls와 sum만 사용하도록 재구성할 수 있다. 시퀀스를 크기가 $b \ge 1$인 chunk로 분할하면 mini-batch gradient descent를 다음과 같이 쓸 수 있다.

\[\begin{equation} \mathcal{M}_t = (1 - \alpha_t) \mathcal{M}_{t-1} - \theta_t \nabla \ell (\mathcal{M}_{t-1}; x_t) = \beta_t \mathcal{M}_0 - \sum_{i=1}^t \theta_i \frac{\beta_t}{\beta_i} \nabla \ell (\mathcal{M}_{t^\prime}; x_i) \\ \textrm{where} \; t^\prime = t - \textrm{mod}(t, b), \; \beta_i = \prod_{j=1}^i (1 - \alpha_j) \end{equation}\]

단순성을 위해, 첫 번째 chunk, 즉 $t = b$에 집중하고 ($t^\prime = 0$), \(\mathcal{M}_t = W_t\)가 선형인 경우에 대한 프로세스만 살펴보자. $N_p \ge 2$인 MLP의 프로세스도 비슷하다. Loss function은 다음과 같이 쓸 수 있다.

\[\begin{equation} \nabla \ell (W_0; x_t) = (W_0 x_t - x_t) x_t^\top \; \Rightarrow \; \sum_{i=1}^b \theta_i \frac{\beta_b}{\beta_i} \nabla \ell (W_0; x_i) = \Theta_b \textbf{B}_b (W_0 X - X) X^\top \\ \textrm{where} \; \Theta_b = \textrm{diag}([\theta_1 \; \theta_2 \; \ldots \; \theta_b]), \; \textbf{B}_b = \textrm{diag}([\frac{\beta_b}{\beta_1} \; \frac{\beta_b}{\beta_2} \; \ldots \; \frac{\beta_b}{\beta_b}]) \end{equation}\]

참고로, $k = 1, \ldots, N/b$에 대해 모든 \(\Theta_{kB}\)와 \(\textbf{B}_{kb}\)를 저장할 필요는 없다. 대신, 각 chunk에 대해 이러한 행렬을 저장하여 메모리를 덜 사용한다.

다음으로, 이 표현을 확장하여 모멘텀 항도 통합할 수 있다. 모멘텀이 있는 chunk별 gradient descent에서 모멘텀 항은 다음과 같다.

\[\begin{equation} S_t = \eta_t S_{t-1} - \theta_t u_t \quad \textrm{where} \; u_t = \nabla \ell (\mathcal{M}_{t^\prime}, x_t) \end{equation}\]

모든 $u_t$를 동시에 계산할 수 있으므로 parallel associative scan을 사용하여 이 chunk에서 $S_t$를 계산할 수 있다.

Chunk의 함수로서의 파라미터

\(\alpha_t\), \(\theta_t\), \(\eta_t\)와 같은 파라미터를 입력 토큰 $x_t$의 함수로 만드는 대신, 이를 해당 chunk의 함수로 만들 수 있다. 표현력은 떨어지지만 이 방법은 학습을 더욱 빠르게 만드는 데 도움이 될 수 있다. 이 경우 각 chunk의 $\alpha$, $\theta$, $\eta$ 각각에 대해 동일한 값을 사용한다. 그렇게 되면 하나의 스칼라로 $\Theta$를 저장할 수 있으며, $S_t$를 global convolution으로 계산할 수 있다. 실험에서는 이 방법을 사용하지 않고 파라미터를 입력 토큰의 함수로 만들었다.

3. Persistent Memory

Long-term memory는 출력이 컨텍스트에 완전히 의존한다. 따라서 long-term memory 외에도 학습 가능하지만 입력과 독립적인 파라미터 집합을 사용하여 task 관련 메모리 역할을 하도록 한다. 이러한 유형의 메모리를 일반적으로 persistent memory 또는 meta-memory라 부른다.

$N_p \ge 1$이 주어지면 학습 가능한 파라미터 \([p_1 \; p_2 \; \ldots \; p_{N_p}]\)를 사용하며 시퀀스의 시작 부분에 추가한다. 즉, 입력을 다음과 같이 수정한다.

\[\begin{equation} x_\textrm{new} = [p_1 \; p_2 \; \ldots \; p_{N_p}] \, \| \, x \end{equation}\]

(\(\|\)는 concatenation)

다음 세 가지 관점에서 persistent memory를 사용하는 동기를 설명할 수 있다.

메모리 관점

앞서 논의했듯이, long-term memory는 모든 파라미터가 입력에 따라 달라지는 전후 관계상 기억(contextual memory)이다. 그러나 효과적인 기억 시스템은 task 지식의 추상화를 저장하기 위해 입력에 독립적인 파라미터가 필요하다. 즉, task를 마스터하려면 task를 수행하는 방법에 대한 지식을 암기해야 하며, persistent memory는 그러한 지식을 저장하는 역할을 한다.

Feedforward Network 관점

Transformer 아키텍처에서 attention 모듈 뒤에 fully connected layer들이 있는데, 이는 attention 가중치와 유사하지만 데이터에 독립적인 파라미터를 갖는다. 즉, fully connected layer의 ReLU를 Softmax로 대체하면 가중치가 데이터 독립적인 attention과 비슷한 가중치가 생성될 수 있다.

\[\begin{equation} \textrm{FFN} (x) = W_V \textrm{Softmax} (W_K x) \end{equation}\]

$W_K$와 $W_V$는 입력에 독립적일 때 attention 모듈의 $K$와 $V$ 행렬과 유사하게 작동한다. Persistent memory 가중치는 동일한 기능을 가질 것으로 예상되며, 이는 시퀀스의 첫 번째 부분에서 사용하면 입력에 독립적인 attention 가중치를 갖게 됨을 의미한다.

기술적 관점

Causal mask를 사용한 attention은 시퀀스의 초기 토큰에 대한 편향을 가지므로 attention 가중치는 거의 항상 초기 토큰에 대해 매우 활성화되어 성능 손상이 발생한다. 기술적 관점에서 시퀀스 시작 시 이러한 학습 가능한 파라미터는 attention 가중치를 보다 효과적으로 재분배하여 이러한 효과를 완화할 수 있다.

How to Incorporate Memory?

Long-term memory를 딥러닝 아키텍처에 효과적이고 효율적으로 통합하는 방법은 무엇일까? 메모리 관점에서 볼 때, Transformer의 $K$, $V$ 행렬 쌍은 의존성에 대한 정확한 모델링과 제한된 context window로 인해 현재 context window 크기에 attention을 하는 short-term memory 모듈로 해석된다. 반면, 데이터에서 지속적으로 학습하고 가중치에 저장할 수 있는 neural memory는 long-term memory의 역할을 할 수 있다.

저자들은 세 가지의 Titans 아키텍처를 제안하여 위의 질문에 답하고자 하였다. 각 아키텍처는 고유한 장단점을 가지고 있으며 매우 긴 컨텍스트에서 효율성과 효과성 간의 상충 관계를 보여준다.

1. Memory as a Context (MAC)


첫 번째 아키텍처는 메모리를 현재 정보에 대한 컨텍스트로 취급한다. 긴 시퀀스 \(x \in \mathbb{R}^{N \times d_\textrm{in}}\)이 주어지면, 먼저 시퀀스를 고정 크기의 세그먼트 \(\{S^{(i)}\}_{i=1}^{N/C}\)로 chunk한다. 들어오는 세그먼트 $S^{(t)}$가 주어지면, 이를 현재 컨텍스트로 간주하고 과거 세그먼트를 과거 정보로 간주한다. 따라서 \(\mathcal{M}_{t-1}\)을 세그먼트 $S^{(t)}$ 이전의 long-term memory 상태라고 하고, 입력 컨텍스트를 \(\mathcal{M}_{t-1}\)에 대한 query로 사용하여 long-term memory에서 해당 정보를 검색한다. 즉, $S^{(t)}$에 해당하는 과거 정보를 다음과 같이 검색한다.

\[\begin{equation} h_t = \mathcal{M}_{t-1}^\ast (\textbf{q}_t), \quad \textrm{where} \; \textbf{q}_t = S^{(t)} W_Q \end{equation}\]

다음으로, 이러한 기록 정보를 persistent memory 파라미터와 함께 attention 모듈에 대한 입력 시퀀스로 사용한다.

\[\begin{aligned} \tilde{S}^{(t)} &= [p_1 \; p_2 \; \ldots \; p_{N_p}] \, \| \, h_t \, \| \, S^{(t)} \\ y_t &= \textrm{Attn} (\tilde{S}^{(t)}) \end{aligned}\]


전체 시퀀스에 걸친 attention map의 구조는 위 그림에 나와 있다. 그런 다음 $y_t$를 사용하여 forward pass를 통해 long-term memory 모듈을 업데이트한다.

\[\begin{aligned} \mathcal{M}_t &= \mathcal{M}_{t-1} (y_t) \\ o_t &= y_t \otimes \mathcal{M}_t^\ast (y_t) \end{aligned}\]

이 아키텍처는 두 가지 주요 장점이 있다.

  1. Attention은 과거와 현재 컨텍스트를 모두 가지고 있기 때문에 현재 데이터가 주어졌을 때 long-term memory 정보가 필요한지 여부를 결정할 수 있다.
  2. Attention 모듈은 long-term memory가 현재 컨텍스트에서 유용한 정보만 저장하도록 돕는다. 즉, 각 세그먼트의 모든 토큰이 유용한 것은 아니며, 모든 토큰을 기억하면 메모리 오버플로가 발생할 수 있다. 따라서 attention은 메모리가 어떤 정보가 유용한지 이해하도록 돕고 메모리 용량을 더 잘 관리한다.

Test-time에 persistent memory는 task에 대한 지식을 인코딩하기 때문에 고정되며, attention 모듈 가중치는 in-context learner이고, long-term memory 모듈은 test-time에도 정보를 학습하고 있다. 즉, 가중치가 과거 데이터를 인코딩하기 때문에 test-time에도 long-term memory의 가중치를 업데이트한다.

2. Memory as Gating (MAG)


MAG 아키텍처는 한 branch에서는 long-term memory를 업데이트하기 위해 입력 데이터를 직접 사용하고 두 번째 branch에서는 sliding window attention (SWA)을 사용한다.

\[\begin{aligned} \tilde{x} &= [p_1 \; p_2 \; \ldots \; p_{N_p}] \, \| \, x \\ y &= \textrm{SW-Attn}^\ast (\tilde{x}) \\ o &= y \otimes \mathcal{M} (\tilde{x}) \end{aligned}\]

(\(\textrm{SW-Attn}^\ast\)은 prefix가 있는 sliding window attention, $\mathcal{M}(x)$는 시퀀스 토큰에 대한 모든 재귀 이후의 최종 메모리 출력, $\otimes$는 non-linear gating)

저자들은 학습 가능한 벡터 가중치를 사용하여 출력 $y$와 $\mathcal{M}(\tilde{x})$를 정규화한 다음 non-linearity $\sigma(\cdot)$를 수행하였다.


MAG의 전체 attention mask는 위 그림에 나와 있다. Sliding window attention은 정확한 단기 기억으로 작동하는 반면, long-term memory 모듈은 모델의 희미해져 가는 기억으로 작동한다. 이 아키텍처 디자인은 head들의 구조가 다른 multi-head 아키텍처로 볼 수도 있다.

3. Memory as a Layer (MAL)


MAL 아키텍처는 long-term memory를 신경망의 레이어로 사용한다. 이 아키텍처 디자인은 full attention 또는 sliding window attention을 사용하여 recurrent model을 쌓는다. 입력 $x$가 주어지면 다음과 같다.

\[\begin{aligned} \tilde{x} &= [p_1 \; p_2 \; \ldots \; p_{N_p}] \, \| \, x \\ y &= \mathcal{M} (\tilde{x}) \\ o &= \textrm{SW-Attn} (y) \end{aligned}\]

($\textrm{SW-Attn}$은 sliding window attention)

이 디자인의 주요 단점은 모델의 파워가 각 레이어에 의해 제한되어 attention과 long-term memory 모듈의 보완적 데이터 처리를 활용할 수 없다는 것이다. 실험에서는 이 디자인에서 메모리를 평가하기 위해 H3와 유사한 아키텍처를 사용하는데, 시퀀스 모델을 long-term memory 모듈 (LMM)로 대체하였다.

Memory Without Attention

위에서는 MAL을 LMM과 attention의 순차적 조합으로 논의했지만, MAL의 한 가지 간단한 변형은 LMM을 attention이 없는 시퀀스 모델로 취급하는 것이다. 기억의 관점에서 다른 구성 요소가 방해를 받더라도 메모리 시스템의 각 부분이 독립적으로 작동해야 한다. 따라서 LMM은 단기 기억, 즉 attention이 없어도 여전히 강력한 모델이어야 한다. 이 모델을 LMM 또는 Titans(LMM)라고 한다.

4. Architectural Details

모든 블록에서 residual connection을 사용한다. 구현에서 query, key, value를 계산하기 위한 non-linear activation으로 SiLU를 사용하고 $\ell_w$-norm을 사용하여 query와 key를 정규화한다.

저자들은 최근의 linear recurrent model들을 따라, 각 query, key, value projection 뒤에 1D depthwise-separable convolution layer를 통합하였다. 성능에 큰 영향을 미치지는 않지만, 이러한 1D convolution은 성능 향상을 보였고 계산적으로도 효율적이다. 또한 최종 output projection 전에 linear layer를 사용하여 정규화 및 gating을 사용하는 최신 아키텍처를 따른다.

Experiments

  • 데이터셋: FineWeb-Edu에서 샘플링
  • 학습 디테일
    • tokenizer: LLama 2 (vocabulary 크기 = 32,000)
    • 학습 토큰 길이: 4,000
    • optimizer: AdamW
    • learning rate: $4 \times 10^{-4}$ (cosine annealing schedule)
    • batch size: 50만 토큰
    • weight decay: 0.1

1. Language Modeling

다음은 기존의 recurrent 및 Transformer 기반의 모델들과 언어 모델링 및 상식적 추론의 성능을 비교한 표이다.

2. Needle in a Haystack

다음은 Single NIAH (S-NIAH)에 대한 성능 비교 결과이다.

3. BABILong Benchmark

다음은 BABILong 벤치마크에 대한 성능 비교 결과이다.

4. The Effect of Deep Memory

다음은 메모리 깊이에 따른 성능을 비교한 그래프이다.


다음은 메모리 깊이에 따른 토큰 처리량을 비교한 그래프이다.

5. Time Series Forecasting

다음은 long-term forecasting에 대한 성능 비교 결과이다.

6. DNA Modeling

다음은 GenomicsBenchmarks에 대한 성능 비교 결과이다.

7. Efficiency

다음은 학습 처리량을 비교한 그래프이다.

8. Ablation Study

다음은 ablation study 결과이다.