PRX 第 3 部分 — 在 24 小時內訓練一個文本生成圖像模型!
PRX 第 3 部分 — 在 24 小時內訓練一個文本生成圖像模型!
前言
歡迎回來 👋
在過去的兩篇文章(第 1 部分和第 2 部分)中,我們探索了擴散模型中廣泛的架構和訓練技巧。我們嘗試孤立地評估每個想法,測量吞吐量、收斂速度和最終圖像質量,並試圖了解究竟是什麼在起作用。
在這篇文章中,我們想要回答一個更具實踐性的問題:
當我們把所有有效的技巧結合在一起時會發生什麼?
我們不再一次優化一個維度,而是將最有前景的成分堆疊在一起,看看在嚴格的計算預算下能將性能推到多遠。
具體來說,我們正在進行一場 24 小時的競速:
這與早期的擴散模型時代相去甚遠,當時訓練具競爭力的模型可能耗資數百萬美元。這裡的目標是展示該領域進化了多少,以及精細的工程設計在僅僅一天的訓練中能帶你走多遠。
這次競速不僅僅是一個有趣的實驗。它很可能成為我們未來大規模訓練方案的基礎。
除了結果之外,我們還開源了我們的代碼 (Github),其中包含:
因此你可以自行重現、修改和擴展一切。
訓練方案
現在讓我們來看看這次 24 小時運行中包含了哪些內容。
X-預測與像素空間訓練
我們使用了來自《Back to Basics: Let Denoising Generative Models Denoise》[Li and He, 2025] 的 x-預測公式。正如在第 2 部分中所見,這使得直接在像素空間進行訓練成為可能,並完全消除了對 VAE 的需求。
我們使用 32 的 patch 大小,並在初始 token 投影層中使用 256 維的瓶頸。這種設計保持了序列長度在可控範圍內,使得像素空間訓練即使在高解析度下在計算上也是可行的。
在 512px 時,序列長度為:
(512/32)^2 = 256
在 1024px 時,序列長度變為:
(1024/32)^2 = 1024
我們沒有遵循通常的 256px → 512px → 1024px 時程,而是直接從 512px 開始,然後在 1024px 進行微調。
憑藉受控的 token 數量和現代硬體,像素空間訓練不再是高不可攀。它只是一個更簡潔、更直接的公式。
感知損失 (Perceptual Losses)
直接在像素空間預測 $x_0$ 的一個非常好的副作用是,我們可以重新利用經典計算機視覺中的整套工具箱。
當你的模型輸出潛變量(latents)時,感知監督會變得尷尬。你麼必須解碼回像素,要麼在學習到的潛空間中定義損失,而這可能與人類感知一致,也可能不一致。一旦你直接預測像素,一切都變得簡單了。你可以完全按照感知損失最初的設計來插入它們。
我們從論文《PixelGen: Pixel Diffusion Beats Latent Diffusion with Perceptual Loss》[Ma et al.] 中汲取靈感,作者在擴散損失之上引入了額外的感知目標。他們表明,增加感知信號可以顯著提高收斂速度和最終的視覺質量。
對於這次 24 小時運行,我們增加了兩個輔助損失:
這個想法很簡單:除了標準的 flow matching 目標外,我們鼓勵預測的清晰圖像在感知特徵空間中與目標圖像匹配。LPIPS 捕捉低級感知相似性,而 DINO 特徵則提供更強的語義信號。
我們保留了與論文相同的總體思路,但調整了一些細節。在我們的實驗中,我們憑經驗發現這樣做效果更好:
這些是微小的實現細節,但在我們的設置中,它們始終能提供更好的結果。
我們對 LPIPS 損失使用 0.1 的權重,對 DINO 感知損失使用 0.01 的權重,與原始論文中推薦的數值一致。
與主要的 transformer 前向傳播相比,這些損失是輕量級的,在我們的設置中,它們僅增加了一點點開銷,同時提供了穩定的質量提升。
使用 TREAD 進行 Token 路由
為了讓每一步的成本更低,我們使用了 TREAD [Krause et al., 2025] 的 token 路由技術,它隨機選擇一部分 token 並讓它們繞過一組連續的 transformer 區塊,然後在稍後重新注入,這樣就不會丟棄任何內容。
我們選擇 TREAD 而非 SPRINT (Park et al., 2025) 主要是為了簡單起見,而且在我們的設置中(512px 下 TREAD 的序列長度為 64 對比 128),SPRINT 額外的複雜性感覺不值得那相當小的額外計算節省。
遵循 TREAD 的方案,我們將 50% 的 token 從 transformer 的第 2 個區塊路由到倒數第二個區塊。
在普通 CFG 下,路由模型的表現可能較差,尤其是在訓練不足時,因此我們實現了一個受《Guiding Token-Sparse Diffusion Models》 (Krause et al., 2025) 啟發的簡單自我引導方案,該方案使用密集預測與路由條件預測進行引導,而不是依賴無條件分支。
使用 REPA 和 DINOv3 進行表示對齊
我們使用 REPA [Yu et al., 2024] 進行表示對齊。
對於教師模型,我們選擇了 DINOv3 [Siméoni et al. 2025],因為它在我們之前的實驗中提供了最佳的質量改進。
具體來說,我們在第 8 個 transformer 區塊應用一次對齊損失,損失權重為 0.5。
由於我們將 REPA 與 TREAD 路由結合使用,我們僅在非路由 token 上計算對齊損失,即那些真正經過我們應用損失的區塊的 token。這保持了 REPA 信號的一致性,並避免了對跳過計算路徑的 token 進行特徵比較。
優化器:Muon
我們使用了 Muon 優化器,採用來自 muon_fsdp_2 的 FSDP 實現,因為它在我們之前的運行中顯示出比 Adam 明顯的改進。
Muon 僅應用於 2D 參數(基本上是矩陣)。其他所有內容(偏置、歸一化、嵌入等)都使用 Adam 進行優化,這就是為什麼配置中有兩個參數組的原因。
訓練設置
我們在三個公開可用的合成數據集上進行了訓練:
時程基本上是:在 512px 快速進行,然後在 1024px 銳化:
我們還保留了權重的 EMA 用於採樣和評估:
結果與總結
以下是我們在整個運行過程中追蹤的評估曲線,以及來自最終檢查點的一些樣本網格:
對於為期一天的訓練運行來說,這已經是一個相當紮實的成果。模型還不完美(你仍然可以發現一些紋理瑕疵、偶爾奇怪的解剖結構,並且在非常困難的提示詞上可能會有些不穩定),但它顯然是可用的。提示詞遵循能力很強,整體美感一致,1024 階段基本上達到了我們的預期:在不破壞構圖的情況下銳化細節。
關鍵的啟示是我們已經非常接近了。剩餘的問題看起來更像是訓練不足的痕跡和有限的數據多樣性,而不是方案中結構性缺陷的跡象。失敗模式與你對一個尚未見過足夠多樣化數據的模型的預期一致。隨著更多計算資源和更廣泛的覆蓋範圍,這種完全相同的設置應該會以相當可預測的方式繼續改進。
放大來看,這次競速也突顯了擴散模型訓練已經取得了多大的進步。通過結合像素空間訓練、高效路由、表示對齊和輕量級感知引導,你現在可以在大約一天的時間內,以不久前聽起來還不切實際的預算獲得一個有意義的模型。
下一步是什麼?
這次 24 小時運行只是一個起點,而不是終點。接下來,我們將繼續以更大的規模推動相同的方案,並迭代數據集混合和標註(captioning)。
這次競速背後的所有代碼和配置,以及在第 1 部分和第 2 部分中使用的完整實驗框架,都可以在 PRX 存儲庫中找到:https://github.com/Photoroom/PRX。
雖然我們不重新分發本次運行中使用的確切訓練數據集,但該流水線是完全可配置的,旨在輕鬆適應你自己的數據。你可以插入不同的數據集,調整單個組件(TREAD、REPA、感知損失、Muon 等),並以最小的摩擦進行受控實驗。我們的目標是使其成為快速擴散模型研究的實用遊樂場,我們希望社區能利用它在自己的設置中探索、基準測試和迭代這些技術。
如果你讀到了這裡,感謝你的閱讀。我們也歡迎你加入我們的 Discord 社區,在那裡我們分享 PRX 的進展和結果,並討論任何與擴散模型和文本生成圖像相關的話題。
暫時再見,敬請期待下一輪實驗! 🚀
社區
· 註冊或登錄以發表評論