[논문리뷰] FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning - 더 빠르고 더 좋다 !

반응형

논문 요약

 

FlashAttention 2는 FlashAttention의 후속 논문으로, GPU에서 더욱 효율적인 어텐션 연산을 수행하기 위한 최적화 기법들을 제안합니다. 
기존의 Attention 연산과의 차이점을 비교해보며  본 논문의 주요 Contribution을 위주로 한 번 살펴보겠습니다. 

 

Contribution

 

1. 알고리즘 최적화를 통한 비-행렬곱 연산 감소
- FlashAttention-2는 softmax 연산 등에서 불필요한 연산을 제거하고 행렬곱 위주의 연산을 수행하도록 알고리즘을 개선했습니다.
- GPU에서 행렬곱이 훨씬 빠르게 수행되므로 전체 연산 효율이 크게 향상되었습니다.
- 반면 기존의 어텐션 구현에서는 비-행렬곱 연산이 상대적으로 많았습니다.

비행렬곱 연산(non-matrix multiplication operations)은 행렬의 곱셈 연산이 아닌 다른 모든 종류의 연산을 의미합니다. 머신러닝과 딥러닝에서 사용되는 신경망은 주로 행렬의 곱셈 연산으로 이루어져 있지만, 곱셈 이외에도 다양한 연산이 사용됩니다. 비행렬곱 연산의 예시로는 다음과 같은 것들이 있습니다. GPU에서 행렬곱 연산은 매우 최적화되어 있는 반면, 비행렬곱 연산은 상대적으로 최적화가 덜 되어 있습니다. 따라서 비행렬곱 연산을 최소화하고 행렬곱 연산을 최대화하는 것이 GPU 활용 효율을 높이는 데 도움이 됩니다.

 

2. 시퀀스 길이에 대한 병렬화를 통한 GPU 활용도 증대
- FlashAttention은 배치 크기와 헤드 수에 대해서만 병렬 처리를 했지만, FlashAttention-2는 시퀀스 길이 차원으로도 병렬화를 확장했습니다.
- 이를 통해 배치 크기나 헤드 수가 적은 경우에도 GPU를 충분히 활용할 수 있게 되었습니다.
- 기존 구현들은 배치 크기나 헤드 수가 적으면 병렬성이 제한되는 한계가 있었습니다.

3. Warp 간 작업 분배 최적화로 shared memory 접근 감소
- 하나의 thread block 내에서도 warp들이 어떻게 작업을 나누느냐에 따라 shared memory 접근 패턴이 달라집니다.
- FlashAttention-2는 warp 간 통신과 shared memory 읽기/쓰기를 최소화하는 방식으로 작업을 분배했습니다.
- 기존 FlashAttention을 포함한 다른 구현들은 이 부분이 상대적으로 덜 최적화되어 있었습니다.

4. 정량적 성능 향상
- 위와 같은 최적화를 통해 FlashAttention-2는 FlashAttention 대비 최대 2배, 표준 어텐션 구현 대비 최대 10배의 속도 향상을 달성했습니다.
- A100 GPU 기준으로 최대 이론 성능 대비 73%까지 도달했는데, 이는 기존 구현들에 비해 매우 높은 수치입니다.

5. 실제 모델 학습 속도 향상
- GPT 스타일 모델 학습에 FlashAttention-2를 적용한 결과, FlashAttention 대비 최대 1.3배, 기존 방식 대비 2.8배의 학습 속도 향상을 얻었습니다.
- A100 GPU 1개로 최대 225 TFLOPs/s의 연산 속도를 보였는데, 이는 학습에 사용되는 전체 연산량의 72%가 실제로 GPU에서 처리되고 있음을 의미합니다.
- 기존의 방법들로는 이처럼 높은 GPU 활용률을 달성하기 어려웠습니다.

 

그림으로 그려보는 FlashAttention2

 

런타임에서 GPU는 세가지 메모리 접근법을 가지고 있습니다.

SRAM: 실제 연산 코어와 함께 위치한 온칩 메모리입니다. 크기는 제한적이지만(A100 카드에서 약 20MB) 극도로 빠릅니다(총 19TB/s 대역폭).
HBM: 오프칩이지만 카드 내부에 있는 메모리로, GPU 내부에 있지만 코어와 함께 위치하지는 않습니다. A100은 40GB의 HBM을 가지고 있지만, 1.5TB/s의 대역폭만 가집니다.
DRAM: 전통적인 CPU RAM입니다. TB 단위로 가질 수 있지만, 약 12.8GB/s 대역폭만 얻을 수 있어 너무 느립니다.

 

 

GPU에서 가장 작은 처리 단위를 "스레드(Thread)"라고 합니다. 

스레드는 덧셈, 뺄셈, 곱셈, 나눗셈과 같은 간단한 산술 연산을 수행하기에 적합합니다.

 

일반적인 GPU 카드에는 수천 개의 CUDA 코어가 있으며, 각 코어는 여러 개의 스레드를 실행할 수 있습니다. 예를 들어, NVIDIA H100은 16,896개의 CUDA 코어를 가지고 있습니다. 스레드는 스레드 블록으로 그룹화되며, 각 블록은 동일한 연산을 실행합니다.

 

예를 들어, 일반적인 NVIDIA GPU 카드는 스레드 블록당 최대 1,024개의 스레드를 가질 수 있습니다. 각 스레드 블록은 빠른 공유 메모리(SRAM)에 접근할 수 있습니다. 이 메모리는 작지만 빠릅니다!

 

대부분의 고성능 GPU는 10MB에서 40MB 사이의 공유 메모리를 가지고 있습니다. 모든 스레드 블록은 또한 큰 전역 메모리를 공유할 수 있습니다. 최신 GPU 대부분은 더 빠른 High Bandwidth Memory(HBM)에 접근할 수 있습니다. HBM은 SRAM보다 1,000배 이상 클 수 있습니다. HBM의 데이터 접근 속도는 빠르지만 SRAM보다는 느립니다.

 

 

 

데이터가 메모리에서 이동하는 방식을 이해하는 것은 더 나은 알고리즘을 작성하는 데 매우 중요합니다.

 

예를 들어, 어텐션 레이어에서는 쿼리와 키 사이의 텐서 곱을 계산해야 합니다.

 

$ S = QK^T $ 이 연산은 스레드 블록에 분산되고, 결과 변수 S는 전역 메모리(또는 사용 가능한 경우 HBM)에 기록됩니다. 이 작업이 완료되면, S 행렬을 다시 스레드로 가져와서 softmax 변환을 계산해야 합니다. 

 

$Attention = Softmax(S)$

 

그리고 다시, 결과 행렬을 전역 메모리로 이동시켜야 합니다. 행렬은 스레드와 전역 메모리 사이에서 앞뒤로 이동합니다. 왜냐하면 연산을 스레드 블록에 격리시키고 SRAM을 사용하여 중간 행렬을 캐시할 수 있는 논리적인 방법이 없기 때문입니다.

요즘 일반적으로 사용되는 한 가지 전략은 S 행렬의 연산을 더 작은 행렬로 분할하여 각 작은 연산을 스레드 블록에 격리시키는 것입니다.

 

아래는 일반적으로 HBM에서 Attention을 연산하는 프로세스와 FlashAttention 기법을 사용하여 Attention을 계산하는 프로세스를 비교합니다. 

반응형