這篇文章介紹了由 ARC 的多位研究員、訪問學者及合作者(包括 Zihao Chen、George Robinson、David Matolcsi、Jacob Stavrianos、Jiawei Li 和 Michael Sklar)所完成的工作。感謝 Aryan Bhatt、Gabriel Wu、Jiawei Li、Lee Sharkey、Victor Lecomte 和 Zihao Chen 提供的建議。
在近期關於機械解釋性(mechanistic interpretability)之務實 與雄心 願景的辯論之後,ARC 決定分享一些我們一直在研究的模型。儘管這些模型體積微小,但對於任何雄心勃勃的解釋性願景來說,它們都是極具挑戰性的測試案例。這些模型是經過訓練以執行演算法任務的 RNN 和 Transformer,參數數量從 8 到 1,408 個不等。我們認為目前能大致完全理解的最大模型擁有 32 個參數;而下一個我們投入了大量精力但仍未能完全理解的模型則有 432 個參數。這些模型可以在此取得:
[ AlgZoo GitHub 儲存庫 ]
我們認為,機械解釋性社群中「雄心派」歷來在「完全理解稍微複雜的模型」上的投入不足,而過多關注於「部分理解極其複雜的模型」。先前已有一些旨在完全理解模型的工作,例如針對執行括號匹配 、模加法 以及更通用的群運算 所訓練的模型,但我們仍不認為該領域已接近能夠完全理解我們的模型(至少不是本文所討論的那種意義上的理解)。如果我們有朝一日要完全理解擁有數十億參數的 LLM,我們可能首先需要達到能輕鬆完全理解幾百個參數模型的程度;我們希望 AlgZoo 能激發研究,幫助我們達到那個階段,或幫助我們正視所面臨挑戰的艱巨程度。
造成這種投入不足的一個可能原因是,對於「解釋」和「完全理解」的含義仍存在哲學上的困惑。ARC 目前的觀點是:給定一個針對特定損失函數優化過的模型,對該模型的「解釋」等同於對模型損失的機械式估計(mechanistic estimate) 。我們透過兩種方式之一來評估機械式估計。我們使用驚訝核算 (surprise accounting)來判斷是否達到了完全理解;但在實際操作中,我們僅觀察均方誤差(MSE)隨計算量變化的函數,這使我們能夠將估計值與採樣進行比較 。
在本文的其餘部分,我們將:
回顧我們將機械式估計視為解釋的觀點,包括我們評估機械式估計的兩種方式。
逐步介紹我們研究過的三個 AlgZoo RNN,其中最小的我們已完全理解,最大的則尚未理解。
總結關於 ARC 的方法與雄心勃勃的機械解釋性之間關係的一些思考。
機械式估計即解釋
AlgZoo 中的模型被訓練來執行簡單的演算法任務,例如計算序列中第二大數字的位置。為了說明為什麼這樣的模型具有良好的性能,我們可以對其準確性產出一個機械式估計 。^([1] ) 所謂「機械式」,是指該估計是基於模型的結構進行演繹推理,這與基於採樣的估計不同,後者是從個別案例中對整體性能進行歸納推理。^([2] ) 關於此概念的進一步說明可以在這裡 找到。
並非所有的機械式估計都是高品質的。例如,如果模型必須在 10 個不同的數字中做出選擇,在進行任何分析之前,我們可能會估計模型的準確性為 10%。這雖然是一個機械式估計,但非常粗糙。因此,我們需要某種方式來評估機械式估計的品質。我們通常使用以下兩種方法之一:
均方誤差對比計算量。 評估機械式估計在概念上最直觀的方法,就是詢問它與模型實際準確性的接近程度。機械式估計消耗的計算量越多,它就應該越接近實際準確性。我們的匹配採樣原則 大致是以下猜想:對於任何給定的計算預算,都存在一種機械式估計程序(在給予適當建議的情況下),其均方誤差表現至少與隨機採樣一樣好。
驚訝核算。 這是一個信息論指標,它詢問:既然我們已經獲得了機械式估計,模型的實際準確性還有多令人驚訝?我們透過兩種方式累積「驚訝」:要麼是估計本身執行了某種計算或檢查並得出了令人驚訝的結果,要麼是在考慮了機械式估計及其不確定性後,模型的實際準確性仍然令人驚訝。關於此想法的進一步說明可以在這裡 找到。
驚訝核算之所以有用,是因為它給了我們「完全理解」的概念:一個機械式估計的總驚訝位元數(bits of surprise),與用於選擇該模型的優化位元數一樣少。另一方面,均方誤差對比計算量則與低概率估計 等應用更相關,且更易於操作。我們越來越專注於匹配隨機採樣的均方誤差 ,這仍然是一個具挑戰性的基準,儘管我們通常認為這比實現完全理解要容易。這兩個指標通常密切相關,我們將在下面的案例研究中逐步介紹這兩個指標的示例。
對於 AlgZoo 中大多數較大的模型(包括下面討論的 432 參數模型 M16,10),如果我們能產出一個在「均方誤差對比計算量」指標下與隨機採樣性能相匹配的機械式估計,我們將認為這是一項重大的研究突破。^([3] ) 而要在驚訝核算指標下實現完全理解將是更艱巨的成就,但我們目前對此關注較少。
案例研究:2nd argmax RNNs
AlgZoo 中的模型根據其訓練任務分為四個家族。我們研究時間最長的是訓練來尋找序列中第二大數字位置的模型家族,我們稱之為序列的「2nd argmax」。
該家族的模型由隱藏層大小 $d$ 和序列長度 $n$ 參數化。模型 $M_{d,n}$ 是一個具有 $d$ 個隱藏神經元的單層 ReLU RNN,它輸入長度為 $n$ 的實數序列,並產出長度為 $n$ 的 logit 概率向量。它有三個參數矩陣:
輸入到隱藏矩陣 $W_{hi} \in \mathbb{R}^{d \times 1}$
隱藏到隱藏矩陣 $W_{hh} \in \mathbb{R}^{d \times d}$
隱藏到輸出矩陣 $W_{oh} \in \mathbb{R}^{n \times d}$
$M_{d,n}$ 對輸入序列 $x_0, \dots, x_{n-1} \in \mathbb{R}$ 的 logits 計算如下:
$h_0 = 0 \in \mathbb{R}^d$
$h_{t+1} = \text{ReLU}(W_{hh}h_t + W_{hi}x_t)$ 對於 $t = 0, \dots, n-1$
$\text{logits} = W_{oh}h_n$
圖示如下:
該家族中的每個模型都經過訓練,使用 softmax 交叉熵損失,使最大的 logit 對應於第二大輸入的位置。
我們將在這裡討論的模型是 $M_{2,2}$、$M_{4,3}$ 和 $M_{16,10}$。對於這些模型中的每一個,我們都想了解為什麼訓練後的模型在標準高斯輸入序列上具有很高的準確性。
隱藏層大小 2,序列長度 2
模型 $M_{2,2}$ 可以使用 zoo_2nd_argmax(2, 2) 在 AlgZoo 中加載。它有 10 個參數,準確性幾乎達到完美的 100%,錯誤率約為 13,000 分之一。這意味著模型 logits 之間的差異,
$\Delta(x_0, x_1) := \text{logits}(x_0, x_1)_1 - \text{logits}(x_0, x_1)_0,$
在 $x_1 > x_0$ 時幾乎總是負值,在 $x_0 > x_1$ 時幾乎總是正值。我們希望從機械論的角度解釋為什麼模型具有這種特性。
為此,首先注意到由於模型使用 ReLU 激活函數且沒有偏置(bias),$\Delta$ 是 $x_0$ 和 $x_1$ 的分段線性函數,其中的分段由 $x_0-x_1$ 平面中通過原點的射線界定。
現在,我們可以透過重新縮放隱藏狀態的神經元來「標準化」模型,從而獲得一個完全等效的模型,其 $W_{hi}$ 的條目位於 ${\pm 1}$ 中。完成此操作後,我們看到
$W_{hi} = \begin{pmatrix} +1 \ -1 \end{pmatrix}, W_{hh} \in \begin{pmatrix} [-1, 0) & [1, \infty) \ [1, \infty) & [-1, 0) \end{pmatrix}$ 且 $W_{oh} \in \begin{pmatrix} (0, \infty) & (-\infty, 0) \ (-\infty, 0) & (0, \infty) \end{pmatrix}.$
從這些觀察中,我們可以證明在 $\Delta$ 的每個線性分段上,
$\Delta(x_0, x_1) = a_0x_0 - a_1x_1$
其中 $a_0, a_1 > 0$,此外,$\Delta$ 的分段在 $x_0-x_1$ 平面中根據下圖排列:
這裡,雙箭頭表示邊界位於其相鄰軸與虛線 $x_0 = x_1$ 之間的某處,但我們不需要擔心它在此範圍內的確切位置。
觀察每個線性分段的係數,我們發現:
這意味著:
在 $x_0 = x_1$ 線上方的藍色和綠色區域中,$\Delta(x_0, x_1) < 0$
在 $x_0 = x_1$ 線下方的藍色和綠色區域中,$\Delta(x_0, x_1) > 0$
在兩個黃色區域中,$\Delta(x_0, x_1)$ 大約與 $x_0 - x_1$ 成正比
這些共同暗示了模型具有幾乎 100% 的準確性。更準確地說,錯誤率是單位圓盤中位於模型決策邊界與 $x_0 = x_1$ 線之間的部分,大約是 $1 / (2\pi \times 2^{11}) \approx 1/13,000$。這與模型實證測得的錯誤率非常接近。
均方誤差對比計算量。 僅使用少量的計算操作,我們就能將模型的準確性機械式地估計到 13,000 分之一以內的精度,而這原本需要數萬個樣本。因此,我們的機械式估計比隨機採樣的計算效率高得多。此外,我們只需計算兩個黃色區域中 $a_0$ 和 $a_1$ 的接近程度,就可以輕鬆產出更精確的估計(精確到浮點誤差範圍內)。
驚訝核算。 如此處 所述,總驚訝量分解為解釋的驚訝量加上給定解釋後的驚訝量。給定解釋後的驚訝量接近 0 位元,因為計算基本上是精確的。對於解釋的驚訝量,我們可以回顧我們採取的步驟:
我們「標準化」了模型,這只是用一個完全等效的模型替換了原模型。這完全不依賴於模型的參數,因此不產生任何驚訝量。
我們檢查了模型所有 10 個參數的正負號,以及 $W_{hh}$ 的 4 個條目中每個條目的大小是否大於或小於 1,產生了 14 位元的驚訝量。
我們從中推導出分段線性函數 $\Delta$ 的形式。這是另一個不依賴於模型參數的步驟,因此不產生驚訝量。
我們檢查了 4 個藍色和綠色區域中哪兩個線性係數的大小較大,產生了 4 位元的驚訝量。
我們檢查了 2 個黃色區域中的兩個線性係數在 $2^{11}$ 分之一的範圍內相等,產生了約 22 位元的驚訝量。
加總起來,總驚訝量約為 40 位元。這與用於選擇該模型的優化位元數相當吻合,因為將黃色區域中的線性係數優化到幾乎相等可能是必要的。因此,我們可以相對放心地說我們已經實現了完全理解。
請注意,我們這裡的分析相當「暴力」:我們基本上是一個接一個地檢查 $\Delta$ 的每個線性區域,並在前期做了一些工作以減少所需的總檢查次數。儘管我們認為這在這種情況下構成了完全理解,但對於更深的模型,我們不會得出同樣的結論。這是因為區域的數量會隨著深度呈指數級增長,使得驚訝位元數遠大於模型權重所佔用的位元數(這是用於選擇模型的優化位元數的上限)。同樣的指數爆炸也會阻止我們在合理的計算預算下匹配採樣效率。
最後,值得注意的是,我們的分析允許我們手動構建一個獲得 100% 準確性的模型,取:
$W_{hi} = \begin{pmatrix} +1 \ -1 \end{pmatrix}, W_{hh} = \begin{pmatrix} -1 & +1 \ +1 & -1 \end{pmatrix}$ 且 $W_{oh} = \begin{pmatrix} +1 & -1 \ -1 & +1 \end{pmatrix}.$
隱藏層大小 4,序列長度 3
模型 $M_{4,3}$ 可以使用 zoo_2nd_argmax(4, 3) 在 AlgZoo 中加載。它有 32 個參數,準確性為 98.5%。
我們對 $M_{4,3}$ 的分析與對 $M_{2,2}$ 的分析大致相似,但模型已經深到我們認為純暴力的解釋不再充分。為了處理這個問題,我們利用模型中各種近似的對稱性來減少總計算操作以及解釋的驚訝量。我們的完整分析可以在這些筆記中找到:
在第二組筆記中,我們為模型的準確性提供了兩種不同的機械式估計,它們使用不同的計算量,具體取決於利用了哪些近似對稱性。我們根據兩個指標分析了這兩個估計。我們發現我們能夠大致匹配採樣的計算效率,^([4] ) 並且我們認為我們或多或少有了完全的理解,儘管這一點尚不明確。
最後,我們的分析再次允許我們手動構建一個改進的模型,其準確性達到 99.99%。^([5] )
隱藏層大小 16,序列長度 10
模型 $M_{16,10}$ 可以使用 example_2nd_argmax() 加載。^([6] ) 它有 432 個參數,準確性為 95.3%。
這個模型已經深到暴力方法不再可行。相反,我們在模型隱藏狀態的激活空間中尋找「特徵」(features)。
在重新縮放隱藏狀態的神經元後,我們注意到由神經元 2 和 4 組成的近似孤立子電路,與任何其他神經元的輸出沒有強連接:
$W_{hi}(2,4) \approx \begin{pmatrix} 0 \ +1 \end{pmatrix}, W_{hh}(2,4),(2,4) \approx \begin{pmatrix} +1 & +1 \ -1 & -1 \end{pmatrix}$ 且 $W_{hh}(2,4),(0,1,3,\dots) \approx \begin{pmatrix} 0 & \dots & 0 \ \dots & \dots & \dots \end{pmatrix}.$
由此可知,在將 RNN 展開 $t$ 步後:
神經元 2 大約是 $\max(0, x_0, \dots, x_{t-2})$
神經元 4 大約是 $\max(0, x_0, \dots, x_{t-1}) - \max(0, x_0, \dots, x_{t-2})$
這可以透過使用恆等式 $\text{ReLU}(a-b) = \max(a,b) - b$ 對神經元 4 進行歸納證明。
接下來,我們注意到神經元 6 和 7 與神經元 2 和 4 一起構成了一個更大的近似孤立子電路:
$W_{hi}(6,7) \approx \begin{pmatrix} -1 \ -1 \end{pmatrix}, W_{hh}(6,7),(2,4) \approx \begin{pmatrix} +1 & 0 \ +1 & +1 \end{pmatrix}$ 且 $W_{hh}(6,7),(0,1,3,\dots) \approx \begin{pmatrix} 0 & \dots & 0 \ \dots & \dots & \dots \end{pmatrix}.$
使用相同的恆等式,可知在展開 RNN $t$ 步後:
神經元 6 大約是 $\max(0, x_0, \dots, x_{t-3}, x_{t-1}) - x_{t-1}$
神經元 7 大約是 $\max(0, x_0, \dots, x_{t-1}) - x_{t-1}$
我們可以繼續下去,將神經元 1 加入子電路:
$W_{hi}(1) \approx (-1), W_{hh}(1),(2,4,6,7) \approx (+1, +1, +1, -1)$ 且 $W_{hh}(1),(0,1,3,\dots) \approx (0 \dots).$
因此,在展開 RNN $t$ 步後,神經元 1 大約是
$\max(0, x_0, \dots, x_{t-4}, x_{t-2}, x_{t-1}) - x_{t-1},$
形成了另一個「留一最大值」(leave-one-out-maximum)特徵(減去最近的輸入)。
事實上,透過推廣這個想法,我們可以手動構建一個使用 22 個隱藏神經元來形成所有 10 個留一最大值特徵的模型,並利用這些特徵達到 99% 的準確性。^([7] )
然而,遺憾的是,要更進一步非常具有挑戰性:
我們利用了 5 個隱藏神經元的近似權重稀疏性,但其餘 11 個隱藏神經元中的大多數連接更為密集。
我們產出了一個具有高準確性的手動構建模型,但我們尚未在訓練模型的隱藏神經元與手動構建模型的隱藏神經元之間建立大部分的對應關係。
我們在分析中使用了近似值,但尚未處理近似誤差,隨著我們考慮更複雜的神經元,誤差會變得越來越顯著。
從根本上說,儘管我們對模型有一定的了解,但我們的解釋是不完整的,因為我們還沒有將這種理解轉化為對模型準確性的充分機械式估計。
最終,為了產出一個能與採樣競爭(或構成完全理解)的該模型準確性機械式估計,我們預計必須以某種方式將這種特徵分析與用於 $M_{2,2}$ 和 $M_{4,3}$ 模型的「利用對稱性後的暴力法」元素相結合,並主要以演算法的方式來完成。這就是為什麼我們認為產出這樣的機械式估計是一個艱巨的研究挑戰。
關於此模型進一步討論的筆記可以在這裡找到:
Zihao Chen 的 RNNs for the 2nd argmax 以及 補充筆記本
結論
AlgZoo 中的模型很小,但除了其中最小的模型外,要機械式地估計其準確性並與採樣競爭,更不用說在驚訝核算的意義上完全理解它們,都是一項相當大的挑戰。與此同時,AlgZoo 模型的訓練任務可以輕易被 LLM 完成,因此完全理解它們實際上是實現雄心勃勃的 LLM 解釋性的先決條件。總體而言,我們熱切希望看到其他以雄心為導向的研究人員探索我們的模型,具體來說,我們期待看到在均方誤差對比計算量意義上更好的機械式估計。我們提出的一個具體挑戰如下。
挑戰 :設計一種機械式估計 432 參數模型 $M_{16,10}$^([8] ) 準確性的方法,使其在均方誤差對比計算量方面與隨機採樣的性能相匹配。衡量均方誤差的一個廉價方法是向模型的權重添加噪聲(足以顯著改變模型的準確性),並檢查該方法在隨機噪聲模型上的平均平方誤差。^([9] )
ARC 更廣泛的方法與此有何關係?我們在這裡展示的分析是相對傳統的機械解釋性,但我們認為這種分析主要是一種熱身。最終,我們尋求一種可擴展的、演算法化的方法來產出機械式估計,我們在最近的工作 中一直在追求這一目標。此外,我們的雄心在於,我們希望充分利用模型中存在的結構來機械式地估計任何感興趣的量。^([10] ) 因此,我們的方法最準確的描述是「雄心勃勃」且「機械式」的,但或許不完全是「解釋性」。
從技術上講,模型是為了最小化交叉熵損失(帶有少量的權重衰減)而訓練的,而不是為了最大化準確性,但兩者密切相關,因此我們略過此區別。↩︎
「機械式估計」一詞基本上與此處 使用的「啟發式解釋」或此處 使用的「啟發式論證」同義,不同之處在於它更自然地指代數值輸出而非產出它的過程,並且具有我們現在更偏好的其他內涵。↩︎
單個模型的估計可能因偶然因素而接近,因此該方法應在隨機種子平均後與採樣相匹配。↩︎
為了評估我們方法的均方誤差,我們向模型的權重添加噪聲,並檢查我們方法在隨機噪聲模型上的平均平方誤差。↩︎
此手動構建模型可以使用 handcrafted_2nd_argmax(3) 在 AlgZoo 中加載。感謝 Michael Sklar 的通信促成了此構造。↩︎
我們將此模型與「官方」模型動物園分開處理,因為它是在我們標準化代碼庫之前訓練的。感謝 Zihao Chen 最初訓練並分析了此模型。動物園中可以使用 zoo_2nd_argmax(16, 10) 加載的模型具有相同的架構,且可能相當相似,但我們尚未對其進行分析。↩︎
此手動構建模型可以使用 handcrafted_2nd_argmax(10) 在 AlgZoo 中加載。請注意,此手動構建模型比訓練模型 $M_{16,10}$ 擁有更多的隱藏神經元。↩︎
我們所指的特定模型可以使用 example_2nd_argmax() 在 AlgZoo 中加載。具有相同架構的其他 2nd argmax 模型(好的方法也應該在這些模型上表現良好)可以使用 zoo_2nd_argmax(16, 10, seed=seed) 加載,其中 seed 等於 0, 1, 2, 3 或 4。↩︎
衡量均方誤差的一個更好但更昂貴的方法是,對用於訓練模型的隨機種子進行平均。↩︎
我們在這種意義上是雄心勃勃的,是因為我們的最壞情況理論方法論 ,但與此同時,我們比內在理解更關注低概率估計 等應用,對於這些應用,部分成功就能帶來務實的勝利。↩︎