世界を動かす技術を、日本語で。

注釈付きトランスフォーマー (2022)

概要

  • Transformer アーキテクチャの詳細な解説と実装例を提供
  • エンコーダ・デコーダ構造、自己注意機構、残差接続、正規化などの主要要素を説明
  • PyTorch によるコード断片とその設計意図の注釈付き解説
  • 逐次計算の削減 と、従来手法との比較を通じた利点の整理
  • 注意機構のマスク処理や、デコーダの出力制御方法も具体的に記述

Attention is All You Need v2022 解説

Transformerの背景と概要

  • Transformer は、従来のRNNやCNNベースのモデルと異なり、 自己注意機構のみ で系列変換を実現
  • 逐次計算の削減 を目指し、全位置間の依存関係を効率的に学習
  • Multi-Head Attention により、単一の注意重みの平均化による情報損失を緩和
  • 自己注意 は、機械読解・要約・言語推論など多様なNLPタスクで有効
  • Transformer は、系列整列型RNNやConvS2Sを使わず、自己注意のみで入出力表現を計算する初のモデル

モデルアーキテクチャ

  • Encoder-Decoder構造 を採用
    • Encoder :入力系列を連続表現へ変換
    • Decoder :エンコーダの出力を参照しつつ、出力系列を逐次生成
  • EncoderDecoderクラス
    • エンコーダ・デコーダ・埋め込み・出力生成器を統合
    • forwardメソッドでマスク付き系列を処理
  • Generatorクラス
    • 線形変換+softmaxで出力語彙分布を生成

エンコーダ・デコーダスタック

  • Encoder
    • 同一構造の層(N=6) を積層
    • 各層は 自己注意+位置ごとの全結合層 から構成
    • 残差接続+LayerNorm を各サブレイヤに適用
    • 入力・出力次元は常に d_model=512 で統一
  • LayerNormクラス
    • 入力系列を正規化し、学習可能なスケール・バイアスパラメータを持つ
  • SublayerConnectionクラス
    • サブレイヤ出力に ドロップアウト+残差 を適用し、LayerNormで正規化
  • EncoderLayerクラス
    • 自己注意機構+フィードフォワード層 を1層にまとめる
  • Decoder
    • N=6層 の積層構造
    • 各層に 自己注意・エンコーダ出力への注意・フィードフォワード の3サブレイヤ
    • 残差接続・LayerNormはEncoderと同様
  • DecoderLayerクラス
    • 自己注意・ソース注意・フィードフォワード を統合
  • マスク処理
    • subsequent_mask関数 で、デコーダが未来の単語を参照しないように制御
    • 出力埋め込みのずらし+マスクで、i番目の予測はi未満の出力のみ参照可能

注意機構(Attention)

  • Attention関数 は、クエリ・キー・バリューの組み合わせから重み付き出力を計算
  • Scaled Dot-Product Attention を採用
    • クエリとキーの内積を$\sqrt{d_k}$でスケーリングし、softmaxで重み化
    • バリューに重みをかけて出力を得る
  • 実装では、複数クエリをまとめてバッチ処理可能

Transformerの利点と従来手法比較

  • RNNやConvS2S は、系列長に比例・対数的な計算コストがかかる
  • Transformer は、全位置間依存性を 定数回の演算 で実現
  • 長距離依存性の学習が容易 で、並列計算にも適する設計
  • Multi-Head Attention により、異なる特徴を同時に抽出可能

実装上の工夫と補助関数

  • PyTorch ベースの実装例を随所に挿入
  • clones関数 で同一層の複製を簡単に生成
  • DummyOptimizer/DummyScheduler など、ノートブック実行用のダミークラスも用意
  • show_example関数 などで、インタラクティブな例示や可視化も可能

まとめ

  • Transformer は、自己注意機構とエンコーダ・デコーダ構造を基盤とした新世代の系列変換モデル
  • 効率性・並列性・長距離依存性 の観点で従来手法を上回る
  • 実装の各要素 (残差接続、LayerNorm、マスク処理など)も明快に設計されている
  • PyTorch での実装例は、学習や応用の出発点として有用

Hackerたちの意見

すごい!これ、本当に良くできてるね!自分もTransformerベースの音声モデルについて調べてたけど、細部までしっかり作られてる。アテンションの概念自体、非線形だから初心者にはちょっと分かりづらいけど、これがすごくよく説明されてる。

これは長い間人気のある記事だね!

アテンションの概念自体、かなり直感的じゃない それが実はカーネルスムージングの再定義に過ぎないって気づくと、すごく直感的になるよ。トランスフォーマーは基本的に積み重なったカーネルを学習してると考えられるから、ガウス過程に驚くほど近い関係になるんだ。

あ、これ自分が書いたやつだ。久しぶりだね。機械翻訳やパースに関わる幸運なブレイクがあって、世紀の重要な発明が自分のニッチな分野で起こったんだ。コードと機械学習の交差点に興味があるよ。もしそれが気になるなら、他にも興味深い記事があるよ。* CUDAについて考える: http://github.com/srush/gpu-puzzles * 有害なテンソル: https://nlp.seas.harvard.edu/NamedTensor * SVGの微分: https://srush.github.io/DiffRast/ * 注釈付きS4: https://srush.github.io/annotated-s4/ 最近、業界に戻ったから、しばらく書く機会がなかったんだ。

これすごいね、リンクと記事ありがとう!

実はこれ、オリジナルじゃなくて現代版だって気づいた。だから、これを書き直したオースティン・ファン、スラージ・スブラマニアン、ジョナサン・サム、カリッド・アルムバラク、ステラ・ビーダーマンに感謝だね。

GPUパズル、めっちゃ楽しかった!全部終わった後、もっとあったらいいなって思ったよ。その過程でたくさん学んだし。

アテンションの部分に入るとき、みんなが「キー・クエリ・バリュー」って表現するのをやめてほしいな。トランスフォーマーにおけるそれぞれの機能に特別な意味はないから。KQV行列自体は、入力ベクトルに学習した重みを掛け算して計算されるんだけど、最終的には正しい結果に収束するランダムな行列なんだ。つまり、最終結果が12になるのに26でも34でも関係ないってこと。トランスフォーマーが機能するのは、多次元性のおかげで、行列同士を掛け算してるからで、ベクトルのドット積を計算するのとは違うんだ。行列の掛け算は実質的にドット積の合計だから、トランスフォーマー全体を幅広の単層パーセプトロンのシーケンスとして表現できる(ただし、ゼロがたくさんあるけど)。でも数学的には同じことをしてるんだよね。

それについては反対だな。K、Q、Vはアテンション計算の中でそれぞれ異なる役割を持ってるから。特にデコード(推論中の次のトークン計算で、プロンプトを処理する初期のプレフィル段階の後に続く)を考えると、進行中のトークンに関連するQベクトルが1つあって、計算済みのすべてのトークンを表す複数のKとVベクトルがあるんだ。

「トランスフォーマー全体を広い単層パーセプトロンのシーケンスとして表現できる」 これは正しくないよ、またアテンションのせいでね。クラシックなパーセプトロンは静的な重みを持っていて、それは入力じゃない。アテンションを計算するために同じ数学的関数を使うことはできるけど、静的な重みはない。アテンションスコアが一方にあって、Vマトリックスがもう一方にあるんだ。実際、静的な重みを持っていて、2つの入力を直接掛け算できないパーセプトロンの集まりがアテンションメカニズムを「発見」することができるのか、ちょっと疑問に思うよ。MLPは一般的な関数近似器だから、十分に大きな数があれば、近づくことはできるのかな?