AI/ML System

[논문 리뷰] SQUEEZED ATTENTION: Accelerating Long Context Length LLM Inference

민사민서 2025. 5. 26. 17:04
반응형

Problem Definition

- long context length 어플리케이션에서 inference cost는 seqlen에 비례하여 선형적으로 증가한다

- LLM 추론이 상당한 computational resource 요구 / memory capacity, bandwidth 요구 => 거대한 사이즈의 KV cache 때문

- in-context learning, document QA, code generation 등의 어플리케이션에서는 입력 컨텍스트의 상당 부분이 고정되어있다

- 모델에 들어가는 것: "fixed context" => 연속된 프롬프트들에서 재사용 가능 + "user input" => online으로 들어오는 유저 요청들

- 이 논문에서는 fixed context(코드, 문서 등등)가 추론 전에 접근 가능한 상황을 타겟팅

high-level overview of KV cache clustering (sparsification)

SQUEEZED ATTENTION에 대한 간단한 preview:

- offline에서 fixed context에 대한 키들을 의미적 유사도(cosine similarity) 기반 클러스터링한다

- online에서는 query token과 centroid를 비교하여, 진짜 중요한 일부 key들만 retrieve하여 exact attention을 수행한다

- centroid lookup 및 retrieval 모두 GPU에서, attention head마다 서로 다른 양의 key 로드할 수 있다

 

Related Works & How is it different?

<Long-Context LLM>

- LLM이 발전하면서 100K, 1M token 길이의 context length까지도 지원하고 있다

- 하지만 KV cache가 critical bottleneck으로 등장했다, 추론 시 memory usage/latency에 큰 영향을 미침

- long-context inference에서 KV cache size를 효율적으로 다루는 기법들이 제시됨

   * quantization, shared KV cache across tokens/layers , token pruning, KV cache sparsification

 

<KV cache sparsification - KV cache eviction>

- token간 중요도를 부여하여 덜 중요한 토큰은 지움으로써 KV cache 사이즈를 줄이다

- attention score contribution 기반, generation 도중 persistent attention pattern 기반, token entropy 기반 등

- 한계 = generation 도중 token importance가 dynamic하게 변한다면? 미리 구해둔 importance가 미래에 들어올 쿼리들을 잘 반영하지 못한다면? 이미 discarded된 token이 나중에 중요해진다면?   => Squeezed Attention에서는 evict 하지 않는다

 

<KV cache sparsification - sparsely loading KV cache>

- QUEST: 연속된 KV cache entries들을 클러스터링, generate 시 동적으로 가장 관련있는 클러스터를 retrieve

- FMM(Fast Multipole Method) 기반: 연속적인 토큰들끼리 cluster를 하는데, 오래된 토큰일수록 coarser-grained하게

- 한계: 물리적인 proximity 기반 clustering, token 간 semantic proximity를 반영하진 못한다, 만약 관련 있는 토큰들이 멀리 떨어져있다면?   => Squeezed attention에서는 embedding similarity 기반 클러스터링

 

- PQCache / RetrievalAttention: prompt를 dense attention으로 처리해 KV cache 생성, CPU로 복사 후 PQ codebook 또는 인덱스 구축, PQ-based 벡터 탐색 또는 KNN vector search 

- 한계: prefill stage에선 사용하지 못한다 (인덱스,코드북 등 빌드해둬야 하므로), time-to-first-token latency가 크다

 

Offline phase: Clustering Keys

fixed context keys를 받아서 cosine similarity(의미적 유사도) 기반 K-means clustering한다

semantic-based (not physical-based) clustering이어서 key token 접근이 비연속적이지 않냐?

  - single head의 KV cache token은 대부분 bf16에서 256B보다 크다 (h_dim > 128 normally)

  - sparsely loading keys 과정에서 bandwidth underutilize 가 발생할 걱정은 없다

centroid를 n-level 로 hierarchical하게 clustering 할 수도 있다

 

Online: Query-Aware Key Retrieval

- query token q에 대한 cluster i의 importance S_i는 위 식과 같이 구한다

- cluster 내 N_j개의 key들의 attention score 평균을 cluster centroid로 근사함

- S_i가 desired threshold T 보다 높으면 통과 , retrieve 한다

- softmax estimate S_i를 사용한 이유는 0-1 범위로 정규화, global threshold를 적용할 수 있기 때문

   * 어떤 attention head는 더 고른 분포를 가져 많은 important key를 retrieve할 수 있고, 다른 attention head는 skewed 분포를 가져 적은 important key만 retrieve할 수 있다

 

Hierarchical Centroid Lookup

fine-grained centroid lookup을 통한 accuracy 개선 vs. coarse-grained centroid를 통한 효율성 개선

=> centroid를 multi-level로 두어도 된다

- offline clustering 단계에서 각 level에서 K-means clustering을 하고, online okkup 단계에서도 각 level에서 query q와 centroid 간의 score를 구해 threshold 적용

 

Complexity Analysis

- c와 c'는 각 level에서 유지할 centroid 개수 (당연히 c' < c 일 것)

- k << L 은 실제로 retrieve하는 key의 개수 , logL 은 유지할 hierarchical level(depth)

 

Implementation Details

<centroid lookup basic idea>

kernel은 기본적으로 3 단계의 pass를 거친다

- initial pass: key centroid 전체를 순회하며 Eq1 의 분모를 계산한다

- second pass: centroid 전체를 순회하며 S_i 를 (계산해둔 분모로) 계산한다

- third pass: target threshold T와 비교하여 load할 후보군을 정한다

 

<Prefill Stage>

- multiple query token이 한 번에 주어진다

- FlashAttention-2 구현을 참고하여 workload를 query sequence length dim 기준 분할한다

- 각 key centroid가 each query token에서 가지는 S_i score를 구한 뒤 평균내서 average importance score를 구한다

 

<Generation Stage>

- single query token만 주어져서 병렬화 적용 가능한 부분이 그나마 different attention head 간에?

- centroid lookup을 3 pass를 2 pass로 압축하여 가속화했다

- pass 1에서 분모와 분자를 동시에 계산 및 저장해둔다음 pass 2 에서 분자 와 분모 * Threshold를 비교

 

<Sparse attention with retrieved keys>

- FlashAttention-2의 기법을 사용, sequence length / head 단위에서 작업을 분할한다

- head 간에 retireve해야 할 key tokens 수가 엄청나게 다르다면? => GPU의 SM이 처리할 수 있는 fixed number elements 개수의 배수가 되도록 keys & values를 split한다 => 해당 head의 work는 GPU의 더 많은 SM들에 의해 병렬화될 것

 

New Benchmark - PreFixQA

 

- LongBench, RULER 등의 기존 벤치마크는 long-context input & single-question 조합이었다

- 즉 offline preprocessing step이 매 qa sample마다 이루어져야 했다

- fixed context - multiple qa samples이 가능한 벤치마크 제작 => Prefix-Fixed QA (PreFixQA)

 

Long Documents Collection

- arXiv에서 47개의 논문 수집, 20페이지 정도의 충분한 분량, 2024년도, 여러 도메인

 

QA Generation

- Llama-3 활용하여 multi-step generation & filtering 과정을 거쳤다

- 균일한 질문 생성을 위해 multiple chunk로 문서를 나누어 각각에 대한 추가적인 QA를 GPT-4-Turbo 사용해 만들었다

- GPT-4-Turbo를 사용해 반복적으로 QA pair를 생성하여 일관성 및 정확도 올렸고, llm보도 judging하게 하여

- 1127개의 QA pair (각 document마다 24개)

=> 더 좋은 방법이 있지 않았을까 하는 생각? (ex. openreview에서 해당 논문 가지고 discussion한 내용들 활용, 저작권 문제되려나?)

 

Evaluation

full KV cache baseline과 비슷한 정확도를 보이더라, QUEST baseline보다 더 좋더라 (semantic-based clustering의 효과)
RULER dataset에서, full KV cache baseline에 대적할만한 accuracy를 보이더라
PreFix QA에 대한 것, full KV cache와 비슷한 성능을 내더라
clustering을 offline에서 한 채로 비교, prefill/generation 에서 full KV cache lookup을 할 필요가 없으므로 더 빨라지더라

 

(총평)

fixed context가 사전에 주어진 상황에서는 효율적일 수 있다, 일반적인 long-context 상황에서 이 접근을 어떻게 적용 가능할까?

코드를 읽어봐야겠지만, retrieval 과정에서 어떻게 타겟 centroid/key values들을 찾아가게 구현해뒀을까 궁금

benchmark 제작 과정에서의 약간의 idea => 만약 long-context multiple QA benchmark가 부족하다면 꽤 괜찮은 방법?

berkeley BAIR 랩 논문을 읽을 때마다 느끼는 건데, 참 글을 읽기 쉽게 쓰더라고요

반응형