Flash Attention - Shinichi0713/LLM-fundamental-study GitHub Wiki
Purpose
フラッシュアテンションの仕組みについて理解、合わせて実装法、効果について確認する
Background
TransformerがNNのデファクトスタンダードとなっている。
Transformerを支える技術がScaled Dot-Product Attention (SDPA)。
SDPAには優れた機構であると同時に、長い系列を扱う時に駅山稜が膨大となるという欠点がある。
上記課題に対して、アテンションに制限を加えて疎にしたり、ソフトマックス部分を置き換えることなどにより、計算量を下げる様々な方法が提案されてきた。
Flash Attentionは計算の内容自体は元のSDPAと同じ。だけど、計算アルゴリズムとGPUメモリへのあくせす方法を工夫することで、GPUメモリ使用量を系列長に対して線形に抑え、計算速度も速くなるという手法。
ちなみに、Hugging Faceには以下のように記載されている。
In practice, there is currently absolutely no reason to not use Flash Attention if available. The algorithm gives mathematically the same outputs, and is both faster and more memory-efficient. ということで、使える状況であれば使うべき。(出力に変化はないので)
attention machanizm
Flash Attentionは、長いSequenceでのTransformer学習ができるようにするという目的で提案された手法で、従来のAttention方法での以下のような問題を解決しようししています。
長いSequenceでの学習が難しい
長い処理のためにBatchSizeを減らすと学習時間が長くなる
指摘しているのは、Qeury x Key のマトリクス計算の部分で、N x N サイズの計算の際、GPUでのデータやりとりが上記の問題の原因になっていると説明
着眼点:IO Awarenessの観点より、GPUとSRAM間のやり取りを加速化することで、計算効率の向上と、高速化を実現する技術。
処理の流れは以下の絵が分かり易い。
KeyとQueryを対応するブロックを取り出しながら、積算をSRMで算出させるイメージ。
IO Awareness(Input/Output Awareness)は、特にコンピュータサイエンスや機械学習の分野で、データの入出力に関連する効率性を考慮した設計やアルゴリズム。 たとえば、GPUとメモリ間のデータ転送を最適化することで、計算速度を向上させる技術が含まれます。
Tiling:データや画面を効率的に分割して処理する技術を指します。例えば、グラフィックス処理やメモリ管理において、データを小さなブロック(タイル)に分割して処理することで、効率を向上させる手法が「Tiling」として知られています。
Flash Attentionを使うことの出来る条件
- データ形式として、ffloat16 and bfloat16を使うこと。
- Cuda Device Propertyが対象であること。
- Maskを使わないこと。(Triangular Matrixは可能)
効果
効果は参考サイトの結果を引用したもの。
Normal Attention VS Flash Attention
Attention Type | VRAM at Start | Sequence Length | Output Text Length | Computation Time (sec) | VRAM at End | VRAM Difference |
---|---|---|---|---|---|---|
Normal Attention | 12.62 GB | 1024 | 506 | 400.2 | 13.06 GB | 0.44 GB |
Flash Attention | 12.62 GB | 1024 | 506 | 333.5 | 12.93 GB | 0.31 GB |
出力結果に違いはないが、計算時間や消費VRAM量には改善の効果があることが確認された。
Key-Valueキャッシュ(KVキャッシュ)とは、データを「キー」と「バリュー(値)」のペアとして保存し、効率的にデータを検索・取得する仕組みです。以下のような特徴があります:キー(Key): データを識別するための一意のラベル。バリュー(Value): キーに対応する実際のデータや情報。この仕組みは、特に以下のような場面で役立ちます:高速なデータアクセス: 必要なデータを迅速に取得できるため、アプリケーションのパフォーマンスが向上します。キャッシュとしての利用: 頻繁にアクセスされるデータを保存し、データベースへの負荷を軽減します。
要点
Transfomerの計算コストとメモリ使用量が大きい(系列長の2乗に比例して増加)という課題を解決するための技術。
従来手法ではクエリとキーの全てのペアの相関を計算していた点に注目して効率化を図った。
工夫点
- メモリ効率の向上 標準のアテンション計算では、クエリ、キー、バリューの全てのペア間の相関を一度に計算し、これを一時的にメモリに保持します。しかし、Flash Attentionでは、このプロセスを分割して行い、メモリ使用量を削減します。
- 分割計算 大規模なシーケンスを小さなチャンクに分割し、各チャンクごとにアテンションを計算します。これにより、メモリに保持する必要のあるデータ量が減り、計算がより効率的になります。
- 並列計算 各チャンクに対するアテンション計算は独立して行うことができるため、並列処理が容易になります。これにより、計算速度が向上します。
参考
flash attention accelerate generative ai revolution
https://www.youtube.com/watch?v=gBMO1JZav44
Flash Attentionを使ってLLMの推論を高速・軽量化できるか?
https://qiita.com/jovyan/items/11deb9d4601e4705a60d