[논문리뷰] FlashAttention: Fast and Memory-Efficient Exact Attentionwith IO-Awareness
이전 attention approximation 관련 논문들은 compute complexity를 줄이는데 집중했다.
ex) block-sparse attention
https://velog.io/@nawnoes/sparse-attention
https://huggingface.co/blog/big-bird
이래서는 wall-clock speedup은 없다. 단순히 FLOP 수만 줄일 뿐이다. runtime은 메모리 접근 시간(IO)에 더 관련이 있다
FlashAttention은 tiling이라는 기법을 통해 IO-aware한 attention 연산을 수행한다
= GPU의 메모리 HBM과 on-chip SRAM 간의 read/write IO 수를 줄이는 더 빠르고 메모리 효율적인 방법을 제시한다.
방식은 간단하다. GPU HBM에서 query, key, value를 관리하고, 적절한 양만(블록 단위) on-chip SRAM에 올려 계산한다
memory-bound operation인 attention 메커니즘에서 memory access의 bottleneck을 줄였다.
동작 방식은 위의 사진 가운데와 같은데, 아래에서 자세히 설명하겠다.
마찬가지로 예전 pytorch의 naive한 implementation과 다르게
1) QK^T multiplication - mask - softmax - dropout - V multiplication 을 따로따로 하지 않고, block 단위로 reduction 하면서 결과를 계속 누적해나간다 => 즉 하나의 kernel로 fuse 했다
2) 이제 N^2 공간을 차지하는 중간 결과 matrix를 HBM에 내려 저장했다 다시 불러올 필요가 없다
3) 마찬가지로 backward에서도 tiling을 적용하기 때문에 forward 방향에서의 모든 연산 결과를 저장해둘 필요가 없다, 필요 시 recompute 하면 된다
* gradient checkpointing(https://only-wanna.tistory.com/entry/Gradient-checkpointing%EC%9D%B4%EB%9E%80)를 IO-aware하게 좀 변형한 느낌
IO-aware한 연산을 진행하므로
- 당연히 더 빨리 동작할거고 (faster wall-clock time)
- 메모리도 덜 사용할거다 (sequence length N에 대해 linear하게 증가함, 기존엔 quadratic)
Background
1. GPU 메모리 계층구조
40-80GB의 HBM gpu memory - 192KB의 on-chip SRAM 메모리
당연히 HBM보다 on-chip SRAM이 bandwidth는 10배 넘게 빠르다, 따라서 연산은 대부분 HBM 메모리 접근에 bound된다
2. performance characteristic
compute-bound => arithmetic ops 수에 수행시간이 결정된다, large inner dim의 matmul, convolution 연산
memory-bound => memory 접근 수에 의해 결정된다, elementwise 연산(activation, dropout), reduction (sum, softmax, layer norm 등등)
3. 일반적인 attention 연산
S와 P(중간 계산 결과)를 HBM에 저장한다, O(N^2)의 메모리를 차지한다
FlashAttention Forward Algorithm
basic idea
- Q, K, V를 제한된 공간의 on-chip에 올려 계산하자, block 단위로 쪼개 메모리 제약 사항을 만족하면 되겠지
- 근데 생각해보면 softmax를 하려면 rowwise max, rowwise summation을 해야 하는데 (한 query 기준 N개의 token에 대한 attension 점수를 normalize) blockwise로 쪼개면 못하지 않냐
- reduction으로 intermediate value를 업데이트해나가면 된다 (** 이게 핵심 로직)
tiling에서 softmax가 어떻게 처리되는가
softmax는 위와 같이 처리된다. rowwise max를 m(x), rowwise sum을 l(x)라고 하자.
block 단위로 쪼개어 계산한다면, 일단 block 단위로 rowwise max, rowwise sum을 구해두고, 나중에 합칠 때 m_global(x)는 max를 취하고, l_global(x)는 가중합을 하면 된다
최종 forward logic은 다음과 같다
1) B_c*d는 K, V의 블록 크기이고, B_r*d는 Q, O의 블록 크기이다
* M은 메모리 크기이다, 4d로 나눈 이유는 Q_i block, K_j block, V_j block, O_i block 4개를 on-chip SRAM에 올려야 하므로
* B_r에서 min이 적용된 이유는 B_c * B_r 이 O(M)을 넘어가면 안되기 때문이다 (blockwise QK^T 연산 중간결과)
2) (K_j, V_j) 블록 쌍을 on-chip으로 로드하고, 거기서 Q의 여러 블록을 traverse하며 대응되는 O_i를 업데이트한다
* K,V outer loop / Q, O inner loop
* m_i 는 전역적으로 관리되는 rowwise max value, l_i는 전역적으로 관리되는 rowwise sum => 매번 업데이트된다
* O_i 또한 outer loop j가 진행됨에 따라 계속 업데이트된다. 계산 방식은 기존 l_i의 효과를 제거하고 이번 루프의 partial output을 더하고 l_i_new로 나누는 방식이다
FlashAttention Forward Algorithm Analysis
O(N^2d)의 FLOPS를 가진다
O(N)의 additional memory가 필요하다
* N^2 형태가 아니라 Q,K,V, 그리고 O 모두 Nxd 형태로 가지고 있으므로
HBM access는 기존 방식 Θ(Nd + N^2) 에서 Θ(N^2d^2M-1) 로 바뀌었다
* 기존 방식에서 dorminant한 항은 N^2 이었다
* 일반적으로 d(64-128)의 제곱보다 M(around 100KB)이 훨씬 크기 때문에 HBM access는 기존 구현 방식보다 훨씬 줄었다
* 계산 방식은 직관적으로 생각해보면 ND/M 만큼의 outer loop를 돌면서 (M에 채울 수 있을 만큼 block을 쪼개 block 단위로 처리하면서) Q의 Nd elements를 로드하기 때문에 (Nd) * (Nd/M) 으로 도출된다
또한 어떤 알고리즘도 N^2d^2M^-1 보다 적은 HBM access로 구현할 수 없다
* 이 flashattention이 HBM access 에 있어서는 lower bound 이다, 논문에선 귀류법으로 증명함
왼쪽 표에 따르면 GFLOPs는 늘었지만 HBM access 수를 줄이니 runtime이 dramatic하게 감소됨을 알 수 있다
=> memory access가 main bottleneck이었다
block size를 키울수록 HBM access가 줄어 runtime이 줄지만, 어느 순간 부터는 다른 오버헤드(연산 등)로 감소하지 않더라
FlashAttention Backward Algorithm
forward랑 비슷한 방식으로 진행된다
FlashAttention in Block-Sparse Attention
- sparse 패턴에 따라 mask를 적용한다
알고리즘은 똑같다, 단지 zero blocks 들은 생략 가능하다, HBM access 수도 그만큼 줄겠지 (sparse 비율만큼)
Experiments
학습 속도, 더 긴 시퀀스가 가능해짐에 따라 높은 퀄리티, efficient memory footprint(linear scaling)에 따른 각종 benchmark에서 우수한 성능 보임
BERT, GPT-2 모델 훈련 시 비슷한 성능에 speedup은 3x 이상 되었으며
long-range arena 벤치마크, Path-X/Path-256 벤치마크, documant 분류 벤치마크에서 다른 attention 기법들에 비해 우수한 성능을 보였으며
runtime/memory-efficiency 때문에 GPT-2에서 4x context length를 키웠음에도 더 빠른 학습 시간을 보였다
등등..
Potential Extensions?
- Multi-GPU Attention: 확실히 이 논문은 single gpu 환경에서의 IO-aware algorithm을 제시하였다. gpu 간 통신까지 고려해 더 발전시킬 수 있을 것이다
- Sparse MLP Layers: MLP는 기본적으로 compute-bound지만 sparse한 곳에서는 mem-bound이다. 여기 적용 가능할지도
- kernel machine learning
저자가 언급한 한계로는 CUDA 수준으로 구현하느라 PyTorch 같은 high-level로는 구현하지 못했다, 다른 DL 분야의 적용이나 Multi-GPU로의 적용은 아직 하지 못했다 라고 하네요
experiment에 사용된 모델들 보니까 새삼 오래 전에 나왔구나 하는 생각이 드네요
tiling 아주 멋진 개념이죠