スパース回路を通じてニューラルネットワークを理解する
私たちは、モデルの振る舞いをより適切に理解するため、よりシンプルかつ追跡可能なステップで思考できるようにモデルを訓練しました。
ニューラルネットワークは今日の最も強力な AI システムを支えていますが、そのすべてを理解するのは依然として困難です。これらのモデルは明示的なステップバイステップの指示によって作成されているのではなく、モデル自身が、タスクを習得するまで何十億もの内部接続、すなわち「重み」を調整して学習しているのです。私たちは学習のルールを設計しますが、それによって生じる具体的な振る舞いを設計しているわけではありません。その結果、人間には容易に解読できない密なつながりの網が形成されます。
AI システムがより高性能になり、科学・教育・医療における意思決定に現実の影響を与えるようになるにつれ、それらがどのように動作しているのかを理解することは不可欠です。解釈可能性とは、モデルがある出力を生成した理由を理解するのに役立つ手法を指します。これを実現する方法は数多く存在します。
たとえば、推論モデルには最終的な答えに至るまでの過程を説明するよう動機付けられています。思考の連鎖(Chain of Thought)による解釈可能性は、これらの説明を活用してモデルの振る舞いを監視します。これはすぐに役立つものであり、現在の推論モデルの思考の連鎖からは、欺瞞などの懸念される行動に関して参考となる情報を得ることができます。しかし、この性質に全面的に依存するのは脆弱な戦略であり、時間の経過とともに破綻する可能性があります。
一方、本研究の焦点である機械論的解釈可能性は、モデルの計算処理を完全にリバースエンジニアリングすることを目的としています。これまでのことろ、すぐに役立つものではありませんでしたが、原理的にはモデルの振る舞いについてより完全な説明を提供できる可能性があります。機械論的解釈可能性では、モデルの振る舞いを最も詳細なレベルで説明しようとすることで、仮定を最小限にし、より確かな自信を得ることができます。しかし、低レベルの詳細から複雑な振る舞いの説明へと至る道のりは、非常に長く困難です。
解釈可能性は、より優れた監視を実現したり、安全性に問題がある行動や戦略的に一致していない行動を早期に警告したりするなど、いくつかの重要な目的を支えています。また、スケーラブルな監視、敵対的学習、レッドチーミングといった、他の安全性向上の取り組みを補完するものでもあります。
本研究は、モデルをより解釈しやすくする方法で訓練することが可能であることを示します。私たちはこの研究を、密結合ネットワークに対する事後分析を補完する有望なアプローチと捉えています。
これは非常に野心的な試みであり、この研究から最も強力なモデルの複雑な振る舞いを完全に理解するまでの道のりは非常に長いものです。それでも、単純な振る舞いに対しては、この手法で訓練されたスパースモデルに、理解可能で、かつ十分な処理を実行できる、小さく分離された回路を構成できることがわかりました。これは、そのメカニズムを理解できるような、より大規模なシステムを訓練するための現実的な道筋が存在するかもしれないことを示唆しています。
これまでの機械論的解釈可能性の研究は、密で絡み合ったネットワークから出発し、それを解きほぐそうとしてきました。これらのネットワークでは、個々のニューロンが数千もの他のニューロンとつながっています。多くのニューロンは複数の異なる機能を担っているように見えるため、理解するのはほぼ不可能です。
しかし、ニューロン数を大幅に増やしつつ、各ニューロンの接続を数十個だけに制限したネットワークのような、最初から絡まりのないニューラルネットワークを訓練したらどうなるでしょうか?そうすれば、得られるネットワークはより単純で、理解しやすくなるかもしれません。これがこの研究の中心的な仮説です。
この原則を念頭に、GPT‑2 に近いアーキテクチャを持つ言語モデルを訓練しましたが、1つだけ小さな変更を加えました。それは、モデルの重みの大部分をゼロに強制したことです。これにより、モデルはニューロン間に生じ得る接続のごく一部しか利用できなくなります。これは非常に単純な変更ですが、モデル内部の計算を大幅に解きほぐすことができると考えられます。
通常の密結合ニューラルネットワークでは、それぞれのニューロンが次の層のあらゆるニューロンと接続しています。スパースモデルでは、各ニューロンは次の層のごく少数のニューロンとしか接続しません。これにより、各ニューロンおよびネットワーク全体がより理解しやすくなることが期待されます。
私たちは、スパースモデルの計算がどの程度「解きほぐされている」かを測定したいと考えました。そのために、いくつかの単純なモデルの振る舞いを対象に、それぞれの振る舞いを担う部分(「回路」と呼んでいます)を分離できるかどうかを確認しました。
そこで、手作業で選び抜いた一連のシンプルなアルゴリズムタスクを用意しました。各タスクについて、そのタスクを実行可能な最小の回路にまでモデルをプルーニングし、その回路がどれほど単純かを調べました(詳細については、論文(新しいウィンドウで開く)をご覧ください)。その結果、より大きく、よりスパース(疎)なモデルを訓練することで、より単純な回路を備えた、より高性能なモデルを作成できることがわかりました。
複数のモデルについて、解釈可能性と性能をプロットしました(左下ほど良い)。スパースモデルのサイズが固定されている場合、スパース度を高める(より多くの重みをゼロに設定する)と性能は低下しますが、解釈可能性は向上します。モデルサイズを大きくすると、この限界を押し広げることができ、能力と解釈可能性を兼ね備えたより大規模モデルを構築できることが示唆されています。
具体例として、Python コードで文字列を正しい種類の引用符で閉じるタスクを考えてみましょう。Python では、'hello' は一重引用符で、"hello" は二重引用符で終わらなければなりません。モデルが、どの種類の引用符で文字列が開始されたかを記憶し、末尾でそれを再現すれば、このタスクを解決できます。
私たちの最も解釈可能なモデルには、まさにこのアルゴリズムを実現する分離された回路が存在するように見えます。

文字列が一重引用符で終わるか二重引用符で終わるかを予測する、スパーストランスフォーマーの回路例。この回路は、5つの残差チャネル(縦の灰色線)、レイヤー0の2つの MLP ニューロン、レイヤー10の1つのアテンションクエリ/キーチャネル、1つのバリューチャネルを使用します。モデルは(1)一重引用符を1つの残差チャネルに、二重引用符を別の残差チャネルにそれぞれエンコードし、(2)MLP 層を使って「任意の引用符を検出するチャネル」と「一重引用符と二重引用符を分類するチャネル」に変換し、(3)アテンション操作で間にあるトークンを無視し、直前の引用符を探してその種類を最終トークンへコピーし、(4)対応する閉じ引用符を予測します。
私たちの定義では、タスクを実行するには上記の接続だけで十分であり、モデルの残りの部分を取り除いても、この小さな回路は正しく動作します。また、それらの接続は必要条件でもあり、これらの数本のエッジを削除するとモデルは動作しなくなります。
さらに、より複雑な振る舞いも調べました。たとえば、以下の変数バインディングのような振る舞いについて、その回路を完全に説明するのはより難しいことです。それでも、モデルの振る舞いを予測できる比較的単純な部分的説明までは十分に達成することができます。
もう1つの回路例(詳細は簡略化されています)。current という変数の型を判断するため、1つ目のアテンション操作は、変数が定義される際に変数名を set() トークンへコピーし、後の別の操作が set() トークンからその型情報を変数の後続の使用箇所へコピーします。これにより、モデルは正しい次のトークンを推論できます。
本研究は、モデルの計算をより理解しやすくするという大きな目標に向けた初期段階にすぎませんが、まだ長い道のりがあります。スパースモデルは最先端モデルよりもはるかに小さく、その計算の大部分はいまだ解釈されていません。
次のステップとして、これらの手法をより大規模なモデルに拡張し、モデルの振る舞いをより詳しく説明できるようにすることを目指しています。強力なスパースモデルの複雑な推論の基盤となる回路パターンを整理することで、最先端モデルを調査する際にどの部分を対象にすればよいかについて理解を深めることができる可能性があります。
スパースモデルの訓練に関する非効率性を克服するには、2つの方法があると考えています。1つは、スパースモデルをゼロから訓練するのではなく、既存の密結合モデルからスパース回路を抽出する方法です。密結合モデルは本質的にスパースモデルよりも展開が効率的です。もう1つは、モデルを解釈可能にするためのより効率的な新しい訓練手法を開発することです。これにより、実際の運用に導入しやすくなる可能性があります。
ここで得られた知見は、より高性能なシステムでもそのまま通用することを保証するものではありませんが、初期の結果としては有望です。私たちの目標は、確実に解釈できるモデルの範囲を徐々に広げ、将来のシステムをより分析しやすく、デバッグしやすく、評価しやすくするためのツールを構築することです。


