newsence
來源篩選

In-context learning of representations can be explained by induction circuits

Lesswrong

This post argues that induction circuits, a known mechanism for bigram recall, are sufficient to explain how large language models mirror graph structures in their internal representations during in-context learning tasks.

newsence

表徵的上下文學習可由歸納電路解釋

Lesswrong
大約 5 小時前

AI 生成摘要

這篇文章主張歸納電路這種已知的二元語法召回機制,足以解釋大型語言模型在執行上下文學習任務時,其內部表徵如何反映出圖形結構的幾何特性。

這是我在 發表內容的轉載。所有程式碼與實驗均可在 取得。

摘要

指出,當大型語言模型(LLM)處理圖形上的隨機遊走(random walks)時,其內部表示會反映出底層圖形的結構。作者對此進行了廣義的解釋,認為 LLM 可以「操縱其表示,以反映完全由上下文指定的概念語義」。在本篇文章中,我們深入研究其底層機制,並提出一個更簡單的解釋。我們認為,歸納電路(induction circuits,; )——一種廣為人知的上下文二元語法(bigram)召回機制——足以解釋 Park 等人觀察到的任務表現與表示幾何結構。

Park et al., 2025 的回顧與重現

我們首先描述 的實驗設置,並在 上重現其主要結果。

圖 1. Park et al. 研究概覽
(a) 網格追蹤任務使用一個 4×4 的單字網格。(b) 模型觀察網格上的隨機遊走(例如:apple bird milk sand sun plane opera ...),其中連續的單字始終是鄰居。隨著序列長度增加,模型開始根據圖形結構預測有效的下一個單字。(c) 令人驚訝的是,模型的有效標記(token)表示幾何結構反映了網格結構:模型在激活空間中將每個節點表示在其鄰居附近。圖表重製自 Park et al.

網格追蹤任務

Park 等人引入了「上下文圖形追蹤」(in-context graph tracing)任務。該任務涉及一個預定義的圖形 $G = (V, E)$,其中節點 $V$ 透過標記(例如:apple, bird, math 等)來引用。圖形的連接結構 $E$ 的定義獨立於標記之間的任何語義關係。模型獲得該圖形上隨機遊走的軌跡作為上下文,且必須根據學習到的連接結構預測有效的後續節點。雖然 Park 等人研究了三種不同圖形結構上的圖形追蹤,但我們專注於他們的正方形網格設置(圖 1)。我們在下方提供實驗設置的細節;除了另有說明外,我們的方法論均遵循 Park 等人的研究。

網格結構。 該任務使用由 16 個不同單字標記組成的 $4 \times 4$ 網格:apple, bird, car, egg, house, milk, plane, opera, box, sand, sun, mango, rock, math, code, phone。^() 每個單字佔據網格中的唯一位置。如果兩個單字在水平或垂直方向上相鄰(非對角線),則稱它們為「鄰居」。這定義了一個鄰接矩陣 $A$,其中 $A_{ij} = 1$ 當且僅當單字 $i$ 和 $j$ 是鄰居。

隨機遊走生成。 序列由該網格上的隨機遊走生成:從隨機位置開始,每一步均勻隨機地移動到一個鄰居。這產生了如 apple bird milk sand sun plane opera ... 的序列,其中連續的單字始終是網格鄰居。遵循 Park 等人的做法,我們使用 1400 個標記的序列長度。

測量準確度。 在時間步 $t$,遊走位於節點 $w_t$,其鄰居為 $N(w_t)$,模型輸出一個在詞彙標記上的分佈 $p_\theta(\cdot | w_{1:t})$。遵循 Park 等人的定義,我們將「規則遵循準確度」(rule following accuracy)定義為分配給有效後續節點的總機率質量:
$$\text{Accuracy} = \sum_{a \in N(w_t)} p_\theta(a | w_{1:t})$$

PCA 視覺化。 為了評估模型的表示是否變得類似於網格結構,我們從較後面的層(32 層中的第 26 層)提取激活值。對於 16 個單字中的每一個,我們透過對序列最後 200 個位置中出現的所有該單字取平均值,計算出類別平均激活向量。然後,我們將這 16 個類別平均向量投影到其前兩個主成分(PCA)上進行視覺化。如果表示幾何結構反映了網格,則相鄰的標記在此投影中應出現在附近。

重現與 Park et al. 的解釋

圖 2 顯示了我們在 Llama-3.1-8B 上重現 Park 等人的主要結果。

圖 2. Park et al. 主要結果的重現
左圖: 網格追蹤任務的模型準確度隨上下文長度增加而提升,在約 1000 個標記後達到 $>90%$ 的準確度。陰影區域顯示 16 個隨機序列的標準差。右圖: 在看到 1400 個標記後,第 26 層類別平均激活值的 PCA 投影。灰色虛線連接網格鄰居。有效表示的幾何結構與數據底層的網格結構相似。

Park 等人將這些發現解釋為幾何重組在任務表現中發揮功能性作用的證據:模型在其表示中學習了圖形結構,而這種學習到的結構正是實現準確預測下一個節點的原因。

「我們看到,一旦模型看到臨界數量的上下文,準確度就會開始迅速提高。我們發現這一點實際上與狄利克雷能量(Dirichlet energy)^() 達到最小值時密切吻合:能量在上下文任務準確度迅速增加之前不久降至最低,這表明在模型能夠做出有效預測之前,數據結構已被正確學習。這使我們得出這樣的結論:隨著上下文規模的擴大,表示會出現一種湧現性的重組,使模型能夠在我們的上下文圖形追蹤任務中表現良好。
— Park et al. (Section 4.1; 原文強調)

一個更簡單的解釋:歸納電路

我們提出,網格追蹤任務可以透過比 Park 等人假設的「上下文表示重組」更簡單的機制來解決:歸納電路(induction circuits,; )。

歸納電路由兩類注意力頭協作組成。前一標記頭(Previous-token heads)將注意力從位置 $i$ 轉向位置 $i-1$,將前一個標記的信息複製到當前位置的殘差流中。歸納頭(Induction heads)則關注當前標記先前出現過的位置之後的位置。兩者結合實現了上下文二元語法召回:「如果 $B$ 之前跟在 $A$ 後面,那麼再次看到 $A$ 時,預測 $B$。」^()

在網格任務中,如果模型在序列早期看過二元組 apple bird,那麼在再次遇到 apple 時,歸納電路可以檢索並預測 bird。由於隨機遊走中的連續標記始終是網格鄰居,因此每個召回的後續標記都保證是有效的下一步。有了足夠的上下文,模型將觀察到每個標記的多個後續標記,並可以對這些標記進行聚合,將機率質量分配給所有有效的鄰居。^()

測試歸納假設

如果模型依賴歸納電路來解決任務,那麼消融(ablate)組成這些電路的注意力頭應該會顯著降低任務表現。我們透過零消融(zero ablation)來測試這一點:將目標注意力頭的輸出設置為零,並測量對任務準確度和上下文表示幾何結構的因果影響。

注意力頭識別。 遵循 的方法,我們利用重複序列上的注意力模式分析來識別歸納頭和前一標記頭,並根據各自的分數對 Llama-3.1-8B 中的所有 1024 個頭進行排名,得出兩個排名列表。

消融程序。 對於每種頭類型,我們消融排名前 $k$ 的頭,並測量對任務準確度和表示幾何結構的影響。作為對照組,我們從排除前 32 個歸納頭和前 32 個前一標記頭之外的所有頭中隨機抽取注意力頭進行消融。所有準確度曲線均對 16 個隨機遊走序列(每個網格起始位置一個)取平均值。隨機頭對照組另外對 4 組獨立的 32 個頭取平均值。

結果

圖 3. 注意力頭消融對任務準確度的影響
左圖: 消融排名靠前的歸納頭會逐漸降低準確度,但模型仍能隨上下文學習。右圖: 消融排名靠前的前一標記頭會導致準確度停滯,即使有更多上下文也無法學習。準確度是對 16 個隨機遊走序列取平均值。灰色線顯示消融 32 個隨機頭(排除前幾名歸納頭和前一標記頭)的效果(對 4 組獨立頭取平均值)。

歸納頭和前一標記頭對任務表現都至關重要。 圖 3 顯示了在注意力頭消融下的任務準確度。消融前 4 個歸納頭會使準確度從 $>90%$ 降至 $\sim 60%$,而消融前 32 個歸納頭則使準確度一路降至 $\sim 30%$。僅消融前 2 個前一標記頭就會使準確度降至 $40%$ 以下,消融前 32 個前一標記頭則進一步使準確度降至 $\sim 20%$。

相比之下,消融 32 個隨機頭僅導致輕微退化(準確度保持在 $>85%$),這表明歸納頭和前一標記頭對任務表現特別重要。

雖然這兩類頭對任務表現都很重要,但它們的消融對上下文學習動態有著本質不同的影響。消融歸納頭會降低表現,但準確度仍隨著上下文長度的增加而上升。相反,消融前一標記頭會導致準確度完全停滯。

圖 4. 注意力頭消融對表示幾何結構的影響
不同消融條件下類別平均激活值的 PCA 投影。左圖: 消融前 32 個歸納頭保留了網格幾何結構。右圖: 消融前 32 個前一標記頭破壞了空間組織。這表明前一標記頭對於幾何結構是必要的,而歸納頭則不然。

消融前一標記頭會破壞表示幾何結構。 雖然這兩類頭對準確度都很重要,但它們對表示幾何結構的影響似乎不同。圖 4 顯示,消融歸納頭在 PCA 視覺化中保留了類網格的幾何結構,因為 2D 投影仍然類似於空間網格。然而,消融前一標記頭破壞了這種結構,導致表示失去了明顯的空間組織。

前一標記混合可以解釋表示幾何結構

在上一節中,我們研究了任務表現,並論證了模型透過使用歸納電路來實現高任務準確度。現在我們研究表示幾何結構,並試圖解釋類網格的 PCA 圖。我們將論證,這種結構很可能是前一標記頭執行的「標記混合」(token mixing)的副產品。

鄰居混合假設

圖 4 顯示,消融前一標記頭會破壞網格結構,而消融歸納頭則會保留它。這表明前一標記頭對於幾何組織是必要的。但是,什麼機制能將前一標記頭與空間結構聯繫起來呢?

前一標記頭將位置 $i-1$ 的信息混合到位置 $i$ 中。在隨機遊走中,位置 $i-1$ 的標記始終是位置 $i$ 標記的網格鄰居。因此,每個標記的表示都會與其鄰居的表示混合。當我們計算單字 $w$ 的類別平均值時,我們是對 $w$ 出現的所有位置取平均,其中每個位置都與其前面的鄰居混合。在多次出現後,$w$ 前面出現各個鄰居的機率大致相等,因此 $w$ 的類別平均值大致編碼了 $w$ 加上其鄰居的平均值。

為了測試僅憑鄰居混合是否能產生觀察到的幾何結構,我們構建了一個極簡的玩具模型(toy model)。

前一標記混合的玩具模型

我們直接在由 $4 \times 4$ 網格節點索引的 16 標記空間中工作。每個節點 $i$ 被分配一個初始隨機向量 $x_i$,從 $\mathcal{N}(0, I)$ 中獨立同分佈(i.i.d.)採樣。僅對原始嵌入 $x_i$ 進行 PCA 會產生一個基本上無結構的雲圖:看不到網格的痕跡。

然後我們應用一個單步的「鄰居混合」:
$$\tilde{x}i = x_i + \alpha \sum{j \in N(i)} x_j$$
其中 $N(i)$ 表示節點 $i$ 的鄰居集合。

僅經過這一步,對 16 個混合向量 $\tilde{x}_i$ 進行 PCA 就能恢復出清晰的 $4 \times 4$ 網格:鄰居在 2D 投影中距離較近,而非鄰居則距離較遠(圖 5)。

圖 5. 一輪鄰居混合從隨機嵌入中創造出網格結構
左圖: 16 個隨機高斯向量的 PCA 投影顯示沒有空間結構。右圖: 在應用一次鄰居混合步驟後,相同的嵌入在 PCA 空間中展現出清晰的網格組織。灰色虛線連接網格鄰居。

個別模型激活值中鄰居混合的證據

鄰居混合假設提出了一個進一步的預測:個別激活值應不僅反映當前標記,還應反映其前一個標記。

我們不將每個單字折疊成單一的類別平均值,而是取長度為 1400 的隨機遊走序列的最後 200 個位置,並將這 200 個殘差流向量投影到用於類別平均值的同一個 2D PCA 空間中。現在每個點對應一個特定的激活值。對於每個點,我們顯示二元組信息:中心顏色表示當前標記 $w_t$,邊框顏色表示前一個標記 $w_{t-1}$。

圖 6. 二元組級別的 PCA 視覺化
每個點代表單個位置的激活值。填充顏色表示當前標記;邊框顏色表示前一個標記。具有相同當前標記但不同前一標記的點形成了明顯的簇,表明該表示編碼了兩者的信息。星形標記顯示標記質心。

個別激活值似乎帶有前一標記混合的指紋(圖 6)。例如,在二元組 plane math 出現的位置,激活值往往位於 plane 和 math 質心之間;而在 egg math 出現的位置,激活值則傾向於位於 egg 和 math 質心之間。我們在所有其他二元組中都看到了類似的「中間」行為。如果 $w_t$ 的表示包含類似「自身」與「前一標記」的混合,而不僅僅取決於當前單字,這正是人們所預期的結果。

局限性

我們的實驗指向一個簡單的解釋:模型透過歸納電路執行上下文圖形追蹤,而類網格的 PCA 幾何結構是前一標記混合的副產品。然而,我們在某些重要方面仍缺乏完整的理解。

玩具模型是顯著的簡化。 我們的鄰居混合規則假設前一標記頭只是將前一個標記的激活值 $x_{t-1}$ 加到當前標記的激活值 $x_t$ 上。實際上,注意力頭會應用值(value)和輸出(output)投影:它們增加的是 $W_O W_V x_{t-1}$,其中 $W_O W_V$ 是一個低秩矩陣(秩 $\le d_{head}$)。這種投影可能會大幅轉換被混合的信息,且值得注意的是,它無法實現恆等映射(至少單個頭不行),因為它是低秩的。我們還將所有內容建模為對靜態向量進行單次混合,而實際網絡擁有許多注意力頭、MLP 塊和多個層,會反覆轉換殘差流。

為什麼網格結構在序列後期才出現? 前一標記頭從序列開始就處於活動狀態,但類網格的 PCA 結構僅在處理了許多標記後才變得清晰可見。如果鄰居混合就是故事的全貌,我們可能會預期幾何結構會更早出現。 開發了一個理論框架,將跨上下文和層級的類圖卷積過程形式化,這可能對幾何結構如何湧現提供更完整的說明。

僅限於上下文網格追蹤任務。 我們的分析僅限於 Park 等人的 $4 \times 4$ 網格隨機遊走任務,在該任務中,二元組複製足以進行下一個標記預測。 同期發現,在這些隨機遊走任務中,上下文表示在很大程度上是「惰性」的——模型編碼了圖形拓撲,但難以將其應用於下游空間推理。然而,在其他設置中,上下文表示的變化可能更具功能性: 顯示上下文範例可以功能性地覆蓋標記的語義。研究歸納電路不足以應對的更複雜上下文學習任務(例如具有層次結構或依賴上下文結構的任務,)也將會很有趣。

結論

我們認為 觀察到的現象可以透過語言模型中廣為人知的機制來解釋。上下文圖形追蹤的任務表現可以由歸納電路很好地解釋,該電路負責召回先前看過的二元組。PCA 圖中可見的幾何組織似乎是前一標記混合的副產品:因為隨機遊走穿越圖形邊緣,前一標記頭將每個位置的表示與圖形鄰居的表示混合,而這種混合本身就足以從無結構的嵌入中產生類網格結構。

這些發現表明,Park 等人觀察到的「表示重組」可能並不反映一種複雜的上下文學習策略,而更像是前一標記頭行為的人為產物。


  • [blocked]^() 所有單字在前面加上空格時,都會被精確地標記為一個標記(例如,「 apple」是一個單一標記)。序列在第一個單字前以空格開始標記化,確保每個單字採用單標記編碼。
  • [blocked]^() 狄利克雷能量測量信號在圖形邊緣上的變化程度。低能量意味著相鄰節點具有相似的表示,因此 Park 等人用它來量化模型的表示在多大程度上遵循圖形結構。
  • [blocked]^() 在文獻中,「歸納頭」一詞有時被用來同時指代單個注意力頭和完整的雙組件電路。我們使用「歸納電路」指代完整的機制,而用「歸納頭」指代關注先前出現位置之後標記的特定頭,以避免歧義。
  • [blocked]^() 例如,如果模型看過 apple bird 和 apple house,它可以在預測 apple 之後的下一個標記時,將機率分配給 bird 和 house。