世界を動かす技術を、日本語で。

ニューラルネットをC言語にコンパイルして高速化する

概要

  • ニューラルネットワーク(NN)で論理ゲートを活性化関数として使用し、Conway’s Game of Life のカーネル関数を学習
  • 学習済みNNから論理回路を抽出し、C言語にコンパイルして推論速度を大幅に向上
  • 元のNNと抽出したCプログラムを比較し、1,744倍の高速化を達成
  • JAXやDifferentiable Logic Gate Networks(DLGN)などの技術的背景解説
  • 論理回路学習の工夫や課題、今後の実験案を紹介

ニューラルネットワークから論理回路へ:Conway’s Game of Life のカーネル学習

  • ニューラルネットワーク(NN)で 論理ゲート を活性化関数として使用する新手法
  • Conway’s Game of Life の 3×3カーネル関数 をNNに学習させる実験
  • 学習後、NNから 論理回路 を抽出し、 C言語 にコンパイルする手法を開発
  • 抽出した回路は 不要なゲートを最適化 して削除し、効率化を実現
  • 元のNN推論と比較し、 1,744倍の高速化 を達成したベンチマーク結果
  • 実装コードは Python/JAX(約354行)C(約331行) で公開

研究背景と着想

  • Google Self Organising Systemsグループの Differentiable Logic Cellular Automata 論文に着目
  • Cellular Automata(CA) :各セルが局所ルールで状態遷移するグリッド
    • 代表例: Conway’s Game of LifeRule 110
    • 局所ルール(カーネル)により複雑な挙動が生まれる
  • Neural Cellular Automata(NCA) :カーネル関数をNNで置き換えたCA
  • Deep Differentiable Logic Gate Networks(DLGN) :重みが0/1固定、各ノードが2入力、活性化関数として16種論理ゲートの線形結合を学習
    • 固定ワイヤにより、どの論理ゲートを使うかを学習する構造

Conway’s Game of Life のルールと論理回路化

  • 2Dグリッド上で 各セルの状態(生死) を8近傍セルを見て決定
    • ルール1:3つの生きた近傍セルがあれば 生存
    • ルール2:2つの生きた近傍セル+自身が生存なら 生存維持
    • それ以外は
  • 9ビット入力(中心+8近傍)→1ビット出力(次状態)となる論理回路設計が課題
  • 近傍セルの 生存数カウント が回路設計の難所
    • XORやANDゲートの組み合わせで2値・4値・8値のカウント回路構築がヒント

JAXによる実装と特徴

  • JAX はPython向けのMLフレームワークで、numpy互換API+自動微分・並列化・JITコンパイル機能を持つ
    • grad :自動逆伝播微分
    • vmap :バッチ並列化
    • jit :GPU対応のJITコンパイル
  • Optax (最適化アルゴリズム)、 Flax (NNライブラリ)などJAXエコシステムも活用
  • 乱数生成 が再現性を持ち、デバッグが容易

論理ゲートの連続緩和と学習

  • 論理ゲートの 離散的な動作 を、微分可能な 連続関数 (例:AND→a*b)で近似
  • 16種類の2入力論理ゲート全てに対し 連続緩和 を定義
  • 各ゲートの重みを softmax で正規化し、NNとして学習可能に
  • 学習後は argmax で最も強いゲートのみを選択し、最終的に 論理回路 として抽出

学習時の工夫・課題

  • 通常のNN(relu活性化)では 重み初期化 が正規分布中心0で良好に収束
  • DLGNでは ワイヤ重み を0/1固定にする必要性
    • ワイヤ重みを連続値やsoftmaxで学習しようとしたが、収束せず断念
    • 固定ワイヤにより 勾配伝播 が安定し、論理ゲート学習が可能に

今後の展望・実験案

  • 他のセルオートマトンや 流体シミュレーション への応用
  • Reintegration Trackingなど 複雑なカーネル回路 の自動発見
  • 開発中に 開発日誌 をつけることで、進捗管理やデバッグが容易になった知見

まとめ

  • ニューラルネットワークで 論理回路を学習・抽出 する手法の有効性を実証
  • C言語への変換・最適化 により、従来NN推論に比べ大幅な高速化を達成
  • JAXによる実装 や連続緩和技術、論理回路抽出の工夫がポイント
  • 今後も 多様な自動回路設計 への展開が期待される

Hackerたちの意見

微分可能論理ゲートネットワークはめっちゃ面白いね。でも、最初から配線が固定されてるのはあんまり好きじゃないな。学習可能な配線についてちょっと粗い研究をしたことがあるけど、4ビットの足し算すら学習できなかったよ。

みんなの楽しみを奪っちゃうけど、特許も取られてるよ :)

「ウェイトアグノスティックニューラルネットワーク」の技術もここで使えると思う。NEATのバリアントを使ってるはず。これでトポロジーや配線を学習できるようになるけど、実際にはかなり遅いかもしれないし、剪定された最適化されたDLGNとあまり変わらないかもね。

ハハ!このアイデアに2年間取り組んできたけど、最近スケーラブルに配線を学習する方法を見つけたよ(入力ビットの数も出力ビットの数も自由に設定できる)。このアイデアに夢中な人と話したいな。

最近HNでDLGAについて読んで、すぐに「これは面白い意見だな」と思ったんだけど、論文から実装するのは難しかった。うまく動かせて、ドキュメントも作ってくれて嬉しい!ありがとう!

ここに作者がいます。質問があればどんどん聞いてね。

この結果は驚きだった?

ゲーム・オブ・ライフのルールにいくつか間違いがあるよ。過密ルールを見落としてるね:生きてるセルが3つ以上の生きた隣接セルを持つと死ぬんだ。 Nit: > 「死んでるセルは死んだまま」っていう厳しい第3のルールがあると思うけど、その言い回しは正確じゃないよ。死んでるセルがずっと死んでたら、第一のルールが機能しないからね。正直、その文は流れにあまり貢献してないと思う。

すごく興味深い仕事の素晴らしいまとめをありがとう!バイナリネットワークや微分可能回路がAIの未来で大きな役割を果たすと思う?今の密なベクトル表現は情報をエンコードするには劣った方法になるんじゃないかって、ずっと思ってるんだ。

ここで面白いのは、単純な移植じゃないってこと。JAXは実装しているアーキテクチャに対してすでにかなり速い。ポイントは、パススルーだけを行うノードを取り除いてネットワークを大幅に縮小し、64ビットを一度にビット演算で計算を大規模に並列化すること。だからこの驚異的なスピードアップが実現できるんだ。

uint64_tのセルをattribute((vector_size(32)))に置き換えて、march=nativeでビルドすると、ビット演算は前と全く同じように動くけど、x64マシンのベクトルユニットが活性化されるよ。いいブログ記事だね、ありがとう!

楽しんでくれてよかった、アドバイスありがとう!

これは限界ケースとしてすごく面白いね。いつも良い例になるし。「効率が全てじゃない」っていうのが、医療や司法など他の多くのシステムと同じように強調されてると思う。この場合、分析によって活性化関数を特定できたけど、高次元の問題ではそれが不可能なんだ。AIの魔法は効率にあるんじゃなくて、他の手段では計算できないものを計算可能にすることにあるんだよ。

Cコンパイラの最適化は、手書きのCと比べてどれくらいのスピードアップを達成できるの?-O0の非最適化アセンブラと比べてどう?最適化されたC/アセンブラが実際に必要ないことをして、残りの非効率に影響を与えてるのは何?

Cのコードは163行あるよ。そのうち、-O3では104行がアセンブリ出力に含まれてる。だから、Cコンパイラはさらに約36.2%の命令を削除できるんだ。特別なことはしてなくて、自動ベクトル化とかはしてないよ。今、プロファイルした結果はこうだ: | instrs (aarch64) | time 100k (s) | conway samples (%) | | -O0 | 606 | 19.10s | 78.50% | | -O3 | 135 | 3.45 | 90.52% | 3.45秒ってのは驚きだね、前に測った4.09秒より速いから。もしかしたらPコアとEコアの違いかも。-O0の時、コンパイラが出力するマシンコードはこんな感じ: 0000000100002d6c ldr x8, [sp, #0x4a0] 0000000100002d70 ldr x9, [sp, #0x488] 0000000100002d74 orn x8, x8, x9 0000000100002d78 str x8, [sp, #0x470] これ、めちゃくちゃひどいよね。例えば、-Ogで試すと、-O3と同じ逆アセンブルが得られる。-01でも同じ逆アセンブルになる。アセンブリ(-Og, -01, -03)はCのかなり直接的な翻訳に見える。良くなってるけど、特にすごいわけじゃない(自動ベクトル化はなし): 0000000100003744 orr x3, x3, x10 0000000100003748 orn x1, x1, x9 000000010000374c and x1, x3, x1 0000000100003750 orr x3, x8, x17 よく見ると、実はレジスタのスピリングが意外と少ないんだ。君が本当に聞きたいことは、私が書いたように:

「命令のレイテンシが1サイクルだと仮定すると、2,590 fpsを期待するべきだ。でも、実際にはほぼ10倍の数字を測ってる!どういうこと?」 これの一部は、逆アセンブルで命令を数えるのを間違えてるからだ。ブログ記事では349命令を使ったけど、実際には135なんだ。この新しい数字で計算し直すと、ビットあたり2.11命令、ステップあたり55.3万命令、3.70 gcycles/sで割ると6,690 fpsになる。これは2,590 fpsより良いけど、24,400にはまだ3.6倍遅い。でも、3.6倍は命令レベルの並列性に起因すると思う。これで君の質問に答えられたらいいな。君の文章が大好きだよ、Gwern。

数年前の関連投稿: https://news.ycombinator.com/item?id=25290112 「NN-512は、完全にAVX-512ベクトル化された、人間が読める独立したC実装の畳み込みニューラルネットを生成するオープンソースのGoプログラムです。」

よくやった!めっちゃ楽しんだよ。うちのライブラリでもこういう最適化が必要だな。[0] うちはANDやXORみたいなゲートを使って微分可能な論理ネットワークを構築してるんだ。勾配降下法を使って回路のような構造をトレーニングすることに重点を置いてる。トレーニングしたモデルを効率的なビット並列Cにコンパイルするアイデアは、まさに私たちが探求しているポストトレーニングの最適化なんだ。ソフトゲートをハードなブール論理に戻す(例えば、しきい値処理や記号置換を使って)ことで、推論用の最適化されたコード(C、WASM、HDLなど)を出力するんだ。ライフゲームのカーネルは、論理ベースのネットが本当に輝く例だね。[0]https://github.com/VoxLeone/SpinStep/tree/main/benchmark

いいね。これをLLMでやったら、誰かがめっちゃお金を払ってくれるよ。