newsence
來源篩選

Weight-Sparse Circuits May Be Interpretable Yet Unfaithful

Lesswrong

I replicate Gao et al.'s findings that weight-sparse models yield smaller, interpretable circuits, but provide evidence that these circuits may be unfaithful to the model's true computations.

newsence

權重稀疏模型的可解釋性與真實性存疑

Lesswrong
19 天前

AI 生成摘要

我重現了 Gao 等人的研究結果,證實權重稀疏模型能產生更小且具可解釋性的電路,但我同時提出證據指出,這些電路可能無法忠實反映模型的真實計算過程。

TLDR:最近,Gao 等人訓練了具有稀疏權重的 Transformer,並引入了一種剪枝演算法來提取解釋模型在特定任務上表現的電路。我複現了他們的主要結果,並提出證據表明這些電路對模型的「真實計算」並不忠實。

這項工作是作為 Anthropic 研究員計畫(Anthropic Fellows Program)的一部分,在 Nick Turner 和 Jeff Wu 的指導下完成的。

前言

最近, 提出了一種令人興奮的方法,用於訓練設計上即具備可解釋性的模型。他們訓練的 Transformer 只有極小比例的權重為非零值,並發現對這些稀疏模型在特定任務上進行剪枝可以產生具備可解釋性的電路。他們的核心主張是,這些權重稀疏(weight-sparse)的模型比普通的稠密模型更具可解釋性,且擁有更小的任務特定電路。在下文中,我複現了支持這些主張的主要證據:在給定的任務損失下,訓練權重稀疏模型確實傾向於產生比稠密模型更小的電路,且這些電路看起來也具備可解釋性。

然而,有些理由讓我們擔心這些結果並不意味著我們捕捉到了模型的完整計算。例如,先前的研究 [, ] 發現,類似的遮罩技術即使應用於隨機權重的模型,也能在視覺任務上取得良好表現。因此,我們可能會擔心剪枝方法會「找」出原本模型中並不存在的電路。我提出的證據顯示這種擔憂是有道理的——具體而言,剪枝後的電路可以:

  • 在荒謬的任務上實現低交叉熵(CE)損失;
  • 即使原始模型的注意力模式明顯是非均勻的,電路卻使用均勻的注意力模式來解決任務;
  • 重新利用節點,使其執行與原始模型中不同的功能;
  • 在與剪枝所用分佈略有不同的輸入上,表現與原模型截然不同。

總體而言,這些結果表明,從權重稀疏模型中提取的電路,即使看起來可解釋,也應對其忠實性(faithfulness)進行審查。更廣泛地說,在可解釋性研究中,我們不應純粹追求推動電路大小與任務表現的帕累托前沿(Pareto frontier)^(),因為這樣做可能會產生對模型行為的誤導性解釋。

在本篇文章中,我將簡要回顧我為測試稀疏模型方法而設計的任務,展示 Gao 等人主要結果的基本複現,然後提供四條證據,表明他們的剪枝演算法產生了不忠實的電路。

我用於訓練和分析權重稀疏模型的程式碼在。它與 Gao 等人的開源類似,但額外實現了剪枝演算法、「橋接」(bridges)訓練、多 GPU 支援以及一個互動式電路檢視器。在我的測試中,訓練速度也快了約 3 倍。

任務

我透過在以下三個自然語言任務上進行剪枝來提取權重稀疏電路。有關訓練和剪枝的更多細節,請參見附錄。

任務 1:代名詞匹配

提示詞的形式為 "when {name} {action}, {pronoun}"

例如:

  • "when leo ran to the beach, he"
  • "when mia was at the park, she"

名字取自預訓練數據集 () 中最常見的 10 個名字(5 男 5 女)。^() 用於剪枝的任務損失是預測最後一個標記("he" 或 "she")的交叉熵(CE)。

任務 2:簡化版 IOI

我使用了標準間接對象識別(Indirect Object Identification, IOI)任務的簡化版本。提示詞的形式為 "when {name_1} {action}, {name_2} {verb} {pronoun matching name_1}"。例如:

  • "when leo went to the shop, mia urged him"
  • "when rita was at the house, alex hugged her"

用於剪枝的任務損失是二元交叉熵:我們首先計算模型僅針對 "him" 和 "her" 的機率分佈(僅對這兩個 logit 進行 softmax),然後使用這些機率計算交叉熵。

任務 3:問號

提示詞是來自預訓練數據集的短句,以句號或問號結尾。篩選條件為:1) 稠密模型以 p > 0.3 的機率預測正確的結尾標記(句號或問號),且 2) 當僅限於句號和問號時,稠密模型分配給正確標記的機率 > 0.8。例如:

  • "why do you want that key?"
  • "that is why I want the key."

用於剪枝的任務損失是二元交叉熵,僅對 "?" 和 "." 的 logit 進行 softmax。

結果

關於提取稀疏電路時層歸一化(Layer Norm)作用的相關研究,請參見附錄。

產生稀疏且具可解釋性的電路

零消融(Zero ablation)比均值消融(Mean ablation)產生更小的電路

在剪枝時,Gao 等人將被遮罩的激活值設置為其在預訓練集上的平均值。我發現,在給定的損失下,零消融通常會導致更小的電路(即在下方除了第三行最右側列之外的所有子圖中)。因此,在計畫的其餘部分我使用了零消融。

權重稀疏模型通常擁有更小的電路

Gao 等人的圖 2 基本上得到了複現。在代名詞和 IOI 任務中,在給定損失下,稀疏模型的電路比稠密模型更小。在問號任務中,只有兩個稀疏模型的電路比稠密模型小,且即便如此,規模的縮減程度也低於其他兩個任務。

權重稀疏電路看起來具備可解釋性

您可以在查看每個任務的電路。懸停或點擊節點會顯示其在原始模型或剪枝模型中的激活值。以下是我對 IOI 電路運作方式的簡要總結;我會在附錄中詳細介紹另外兩個任務的電路。這裡介紹的每個電路都是從 $d_{model}=1024$ 的模型中提取的;我沒有像這樣仔細檢查從其他模型提取的電路。本節中顯示的所有逐標記(per-token)激活值均取自剪枝後的模型,而非原始模型。

IOI 任務 ()

以下是來自第 1 層 attn_out 的一個重要節點。當 name_1 為女性時,它呈正向激活;當為男性時,呈負向激活。接著它會抑制 "him" 的 logit。

為了了解該節點的激活是如何計算的,我們可以檢查它讀取的數值向量(value-vector)節點,以及相應的鍵(key)和查詢(query)節點。下圖顯示的數值向量節點在男性名字上呈負向激活:

這裡有兩組查詢-鍵對(query-key pairs)。第一組查詢向量始終具有負激活(未顯示)。相應的鍵節點激活為負,其強度大致隨標記位置增加而減小:

另一組查詢-鍵對執行相同的操作,但激活為正。因此,注意力頭最強烈地關注提示詞的第一部分,所以只有當數值向量節點出現在句子開頭附近時,attn_out 節點才會獲得較大的貢獻。具體而言,當 name_1 為男性時,attn_out 節點會獲得較大的負向貢獻。另一個未在此顯示的數值向量節點在 name_1 為女性時提供正向貢獻。這解釋了我們上面看到的 attn_out 節點的激活模式。

審查電路的忠實性

剪枝在荒謬任務上也能實現低任務損失

我修改了代名詞任務,使得當名字為男性時,目標標記為 "is";當名字為女性時,目標為 "when"。例如:

  • "when rita went to the woods, when"
  • "when leo went to the woods, is"

與標準代名詞任務一樣,任務損失僅為標準交叉熵損失(即對所有 logit 進行 softmax,我沒有使用二元交叉熵)。這是一個荒謬的任務,但剪枝後的模型僅用約 30 個節點就實現了任務損失 < 0.05(意味著準確率 > 95%)。

在普通的代名詞任務上,達到類似損失大約需要 10 個節點。因此,荒謬任務確實需要比真實任務更大的電路,這在某種程度上令人安心。即便如此,任何電路竟然能在這種荒謬任務上獲得如此低的損失,這點仍令人擔憂,而且 30 個節點真的不算多。

您可以在查看這個荒謬任務的電路。

重要的注意力模式在剪枝模型中可能缺失

代名詞電路僅在第 1 層第 7 頭擁有注意力節點。在原始模型中,正如人們所預期的,該頭強烈地從最後一個標記關注到名字標記("rita")。但在剪枝後的模型中,其注意力模式是均勻的(因為沒有查詢或鍵向量節點):

剪枝後的電路是如何在不計算注意力模式的情況下蒙混過關的?它是透過讓其所有的數值向量節點都成為那些在名字上強烈觸發、而在其他地方觸發極弱的節點來實現的。因此,即使該頭關注所有標記,它也只會從名字標記中轉移資訊。這種機制在原始模型中是不可用的。我們發現的電路遺漏了原始模型運作的一個關鍵部分。

節點在剪枝模型中可能扮演不同的角色

範例 1: 以下是來自的第 0 層 1651 號節點的激活值。左圖顯示其在剪枝模型中的激活,它在女性名字上呈負向激活(紅色)。右圖顯示其在原始模型中的激活,它在男性名字上呈正向激活(藍色)。在這兩種情況下,它對所有非名字標記的激活都非常接近於零。這很奇怪:節點在剪枝後獲得了不同的含義。

範例 2: 以下是來自的第 1 層 attn_out 244 號節點的激活值。在剪枝模型中,該節點在 name_1(第一個出現的名字)為女性的語境下呈正向激活,在男性語境下呈負向激活。特別是,只有當 name_1 為女性時,最後一個標記的激活才為正,且正如預期的,該節點直接抑制了 "him" 的 logit。因此,在剪枝模型中,該節點扮演的功能角色是「檢測 name_1 的性別並增強/抑制相應的 logit」。但在原始模型中,最後一個標記的激活並不依賴於 name_1,因此它不可能扮演相同的功能角色。

範例 3: 以下是來自的第 1 層 mlp_out 1455 號節點的激活值。在剪枝模型中,該節點是一個疑問句分類器:它在疑問句上呈負向激活,在其他地方大致為零。它被用來抑制 "?" 的 logit。但在原始模型中,它並不是疑問句分類器。特別是,它在句子最後一個標記上的激活並不能預測該句子是否為疑問句,因此它不可能在幫助提升正確的 logit。

剪枝後的電路可能無法像基礎模型那樣泛化

回想一下,IOI 提示詞看起來像 "when {name_1} was at the store, {name_2} urged __"。我們使用僅包含 name_1name_2 性別相反的提示詞訓練集進行剪枝。在訓練集上表現良好的電路顯然有兩種:

  • 好的電路:輸出與 name_1 性別一致的代名詞。
  • 壞的電路:輸出與 name_2 性別相反的代名詞。

令 $P_{correct}$ 為分配給正確目標標記的平均機率,我們僅透過對 "him" 和 "her" 標記進行 softmax 來計算機率。在這裡,我專注於 $d_{model}=3072$ 的模型,它在性別相反的提示詞上正確完成任務的機率為 89%,在性別相同的提示詞上為 81%。

我使用不同的隨機遮罩初始化和數據順序運行了 100 次剪枝。下圖顯示了性別相反提示詞(左)和性別相同提示詞(右)的 $P_{correct}$ 結果分佈。我過濾掉了未能達到交叉熵 < 0.15 的運行,總共剩下 77 個種子。
剪枝經常只找到「壞電路」(見性別相同直方圖中 0 處的大峰值)。這很糟糕,因為實際的原始模型在性別相同的情況下 $P_{correct}=0.81$,因此它一定是在使用好的電路。

另外,使用相同的超參數但不同的隨機種子進行剪枝,會導致電路具有完全不同的分佈外(OOD)行為,這也令人擔憂。

結論

上述結果提供的證據表明,Gao 等人的剪枝方法雖然能找到規模小、具可解釋性且任務損失低的電路,但這些電路對於模型真正的運作方式是不忠實的。這些結果並未過多評論權重稀疏訓練本身是否為一個有前景的方向;它們僅顯示剪枝演算法存在缺陷。

我的主要心得是,我們不應純粹旨在改善損失與電路大小的帕累托前沿。僅在這一指標上進行爬山演算法(Hill-climbing)很可能會產生看起來吸引人但實際上具有誤導性的機械論解釋。例如,零消融改善了前沿,所以我很早就切換到了它。但事後看來,均值消融可能會產生更忠實的電路(代價是對於像代名詞性別匹配這樣簡單的任務,會產生約 100 個節點的電路,這比我預期的要多得多)。

我認為權重稀疏研究方向的自然下一步是:1) 想出一個好的忠實性指標^()(像因果擦除 [causal scrubbing] 這樣的想法似乎方向正確,但可能過於嚴格);2) 弄清楚如何修改剪枝演算法,以提取根據該指標是忠實的電路;3) 檢查當我們使用修改後的剪枝演算法時,Gao 等人的主要結果——權重稀疏模型具有更小的電路——是否仍然成立。^()

我也感興趣於對歸因圖(attribution graphs)的忠實性進行類似的審查。^() 我預計歸因圖會比我在本研究中發現的電路更忠實(粗略地說,因為它們的剪枝方式並非直接針對下游交叉熵損失進行優化),但應該有人去驗證這一點。我特別感興趣於尋找歸因圖對模型在未見過的提示詞上的行為做出定性錯誤預測的情況(類似於上述 IOI 任務中剪枝找到「壞電路」的情況)。

附錄 A:訓練與剪枝細節

我對權重稀疏訓練的實現幾乎完全複製自 Gao 等人。這裡我僅提到幾點不同之處和感興趣的地方:

  • 我訓練了具有各種大小和稀疏度的兩層模型:

| $d_{model}$ | 非零比例 | 非零參數數量 |
| :--- | :--- | :--- |
| 128 | 1 (稠密) | 1.4M |
| 1024 | 1/64 | 0.5M |
| 3072 | 1/200 | 1.3M |
| 4096 | 1/500 | 0.9M |

  • 遺憾的是,我沒有確保每個模型具有相同數量的非零參數。然而,下文中我唯一比較不同模型的時候是在比較它們的電路大小 vs 任務損失帕累托曲線,這只是對 Gao 等人主要結果的複現,並非本文的重要部分。
  • 每個模型都在來自 的 2B 個標記上進行訓練。
  • 這些權重稀疏模型是與映射自/向稠密模型的「橋接」一起訓練的。這是我的一個錯誤選擇,因為我的結果最終完全沒有用到橋接,且它們增加了複雜性。即便如此,我預計獨立權重稀疏模型的結果基本上會是一樣的。
  • 我在每個殘差流位置施加了輕微(25%)的激活稀疏性(而 Gao 等人僅在 mlp_out 等其他點施加),因為這能略微改善損失。
  • 在訓練的前半段,Gao 等人將非零參數的比例從 1(完全稠密)線性衰減到其目標值。我使用了指數衰減計畫,因為這能略微改善損失。

對於剪枝,我再次緊隨 Gao 等人的做法,僅有微小差異:

  • 如正文所述,Gao 等人透過均值消融來遮罩節點,而我發現零消融能產生更小的電路。
  • 我研究的任務類型涉及自然語言而非程式碼。

附錄 B:代名詞與問號電路詳解

代名詞任務 ()

該電路中的所有計算都經過第 1 層的兩個數值向量節點。下圖顯示的一個在男性名字上呈負向激活。由於沒有查詢或鍵節點,注意力模式是均勻的,名字之後的每個標記都從該數值向量節點獲得相同的貢獻。追蹤到 logit 節點可以發現,該數值向量節點增強了 "she" 並抑制了 "he"。另一個數值向量節點執行相同的操作,但性別相反。
*

請注意,第 1 層中的 mlp_out 節點沒有連接到上游電路節點的傳入權重,因此它們的激活是常數偏差(constant biases)。

問號任務 ()

最早充當疑問句分類器的節點位於第 1 層的 attn_out

attn_out 節點讀取一個數值節點,該節點在「疑問詞」("why"、"are"、"do" 和 "can")上呈正向激活,在代名詞上呈負向激活:

該頭的查詢節點具有正向激活(未顯示)。鍵節點的激活大致^()隨標記位置增加而減小:

因此,如果提示詞包含 "do you",則該頭更強烈地關注 "do",因此 attn_out 從疑問詞 "do" 獲得較大的正向貢獻,而從代名詞 "you" 僅獲得較小的負向貢獻。另一方面,如果提示詞包含 "you do",則該頭更強烈地關注 "you",因此 attn_out 從 "you" 獲得較大的負向貢獻,而從 "do" 僅獲得較小的正向貢獻。

綜合以上,attn_out 節點在包含 "do you" 的提示詞上呈正向激活,在包含 "you do" 的提示詞上呈負向激活,對於疑問詞和代名詞的其他組合也是如此。因此,attn_out 節點的功能是作為疑問句檢測器。

附錄 C:層歸一化(Layernorm)的作用

剪枝後的模型僅用極少數節點就能獲得極低的任務損失。這有一種可能的「作弊」方式:

剪枝會降低激活值的範數(norm)。特別是,下圖顯示在最後一層之後的殘差流 RMS——即輸入到最終層歸一化的激活值——在剪枝模型中更小。因此,最終層歸一化在剪枝模型中對激活值的放大倍數比在原始模型中更大。

現在,假設原始模型有許多節點,每個節點在最終層歸一化之前都向「正確方向」寫入少量資訊(我指的是能透過解嵌入 [unembed] 來增強正確 logit 的方向)。剪枝後的電路僅包含少數這些節點,因此它只向正確方向寫入少量資訊。但它能蒙混過關,因為最終層歸一化將激活值放大了許多,以至於即使是正確方向上的一個微小分量也會強烈增強正確的 logit,從而導致良好的交叉熵損失。

下面,我們將常規剪枝與修改後的版本進行比較,在修改版中我們凍結了層歸一化比例(即層歸一化除以激活值的值)。也就是說,對於每一批數據,我們運行原始模型,保存其所有的層歸一化比例,然後在剪枝模型的正向傳播中將其植入。正如上述分析所預測的,在給定損失下,凍結層歸一化會導致大得多的電路。

對於較大模型 ($d_{model}=3072$) 的 IOI 任務,凍結層歸一化(下圖)比標準剪枝(上圖)在性別相同提示詞上具有更好的泛化能力:

然而,對於較小模型 ($d_{model}=1024$),結果卻相反。也就是說,凍結層歸一化導致電路的泛化能力比未凍結時更差。我發現這個結果令人驚訝。

我認為在剪枝過程中凍結層歸一化在「道義上是正確的」,這樣模型就不能以描述的方式作弊。但這樣做似乎並不能完全解決忠實性問題(見直接上方的 IOI $d_{model}=1024$ 結果)。

關於本附錄結果的最後一個提醒:對於每個模型和任務,我進行了 搜索以找到剪枝的最佳超參數,然後將這些最佳超參數用於 100 次隨機種子剪枝運行。情況可能是,例如對於 $d_{model}=1024$,我們恰好找到了導致性別相同提示詞泛化不佳的「不幸」超參數,而對於 $d_{model}=3072$ 模型,我們找到了「幸運」的超參數。換句話說,這 100 個種子或許並沒有我們希望的那樣去相關。

  • ^() 其中較小的電路規模和較低的交叉熵損失更好。在這項工作中,電路規模指的是電路中的節點數量。

  • ^() 如果你感興趣的話,名字分別是 Leo, Samuel, Jose, Peter, Alex, Mia, Rita, Kim, Maria, Lily。

  • ^() 例如,我嘗試將模型中每個節點的「重要性」定義為當該節點被消融時任務損失的增加量,然後計算電路中包含的前 20 個(假設)最重要節點的比例。我觀察的所有電路在這一指標上得分都很低。但那些未在電路中發現的所謂重要節點,通常具有稠密且看似不可解釋的激活(我在均值消融和零消融中都看到了這一點),我懷疑它們扮演的是非常底層的「保持激活在分佈內」的角色,類似於常數偏差項。因此,我不確定透過上述定義的重要性得分來判斷忠實性是否完全正確。

  • ^() 如果你有興趣從事這類工作,請透過 jacobcd52 at gmail dot com 與我聯繫。

  • ^() 例如,我很想看到更多像一樣的研究。

  • ^() 激活值僅「大致」隨標記位置增加而減小這一事實,讓我懷疑我的機械論解釋有所遺漏。注意力模式可能比單純的「隨便一個大致遞減函數」更細微,但我還沒有努力去理解它們。