突然火起來的diffusion model是什麼?

語言: CN / TW / HK

theme: arknights

我正在參加「掘金·啟航計劃」


模型起源

2015年的時候,有幾位大佬基於非平衡熱力學提出了一個純數學的生成模型 (Sohl-Dickstein et al., 2015)。不過那個時候他們沒有用程式碼實現,所以這篇工作並沒有火起來。

直到後來斯坦福大學(Song et al., 2019) 和谷歌大腦 (Ho et al., 2020) 有兩篇工作延續了15年的工作。再到後來2020年穀歌大腦的幾位大佬又把這個模型實現了出來(Ho et al., 2020),因為這個模型一些極其優秀的特性,所以它現在火了起來。

擴散模型可以做什麼?呢它可以做一些。條件生成和非條件生成。在影象、語音、文字三個方向都已經有了一些應用,並且效果比較突出。

比較出圈的工作有我剛介紹的text to image的生成工作比如

什麼是擴散模型?

Diffusion model 和 Normalizing Flows, GANs or VAEs 一樣,都是將噪聲從一些簡單的分佈轉換為一個數據樣本,也是神經網路學習從純噪聲開始逐漸去噪資料的過程。 包含兩個步驟: - 一個我們選擇的固定的(或者說預定義好的)前向擴散過程 $q$ ,就是逐漸給圖片新增高斯噪聲,直到最後獲得純噪聲。

  • 一個需要學習的反向的去噪過程 $p_\theta$,訓練一個神經網做影象去噪,從純噪聲開始,直到獲得最終影象。

image.png

前向和反向過程都要經過時間步$t$,總步長是$T$(DDPM中$T=1000$)。

你從$t=0$開始,從資料集分佈中取樣一個真實圖片$x_0$。比如你用cifar-10,用cifar-100,用ImageNet,總之就是從你資料集裡隨機取樣一張圖片作為$x_0$。

前向過程就是在每一個時間步$t$中都從一個高斯分佈中取樣一個噪聲,將其新增到上一時間步的影象上。給出一個足夠大的$T$,和每一時間步中新增噪聲的表格,最終在$T$時間步你會獲得一個isotropic Gaussian distribution

我要開始上公式了!

我們令$q(x_0)$是真實分佈,也就是真實的影象的分佈。

我們可以從中取樣一個圖片,也就是$x_0 \sim q(x_0)$ 。

我們設定前向擴散過程$q(x_t|x_{t-1})$是給每個時間步$t$新增高斯噪聲,這個高斯噪聲不是隨機選擇的,是根據我們預選設定好的方差表($0 < \beta_1 < \beta_2 < ... < \beta_T < 1$)的高斯分佈中獲取的。

然後我們就可以得到前向過程的公式為: $$ q({x}t | {x}{t-1}) = \mathcal{N}({x}t; \sqrt{1 - \beta_t} {x}{t-1}, \beta_t \mathbf{I}). $$

$$ \mathcal{N}({x}t; \sqrt{1 - \beta_t} {x}{t-1}, \beta_t \mathbf{I}) 就是{x}t \sim \mathcal{N}( \sqrt{1 - \beta_t} {x}{t-1}, \beta_t \mathbf{I}). $$

回想一下哦。一個高斯分佈(也叫正態分佈)是由兩個引數決定的,均值$\mu$和方差$\sigma^2 \geq 0$。

然後我們就可以認為每個時間步$t$的影象是從一均值為${\mu}t = \sqrt{1 - \beta_t} {x}{t-1}$、方差為$\sigma^2_t = \beta_t$的條件高斯分佈中畫出來的。藉助引數重整化(reparameterization trick)可以寫成

$$ {x}t = \sqrt{1 - \beta_t}{x}{t-1} + \sqrt{\beta_t} \mathbf{\epsilon} $$

其中$\mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$,是從標準高斯分佈中取樣的噪聲。

$\beta_t$在不用的時間步$t$中不是固定的,因此我們給$\beta$加了下標。對於$\beta_t$的選擇我們可以設定為線性的、二次的、餘弦的等(有點像學習率計劃)。

比如在DDPM中$\beta_1 = 10^{-4}$,$\beta_T = 0.02$,在中間是做了一個線性插值。而在Improved DDPM中是使用餘弦函式。

從$x_0$開始,我們通過$\mathbf{x}_1, ..., \mathbf{x}_t, ..., \mathbf{x}_T$,最終獲得${x}_T$ ,如果我們的高斯噪聲表設定的合理,那最後我們獲得的應該是一個純高斯噪聲。

現在,如果我們能知道條件分佈$p({x}_{t-1} | {x}_t)$,那我們就可以將這個過程倒過來:取樣一個隨機高斯噪聲$x_t$,我們可以對其逐步去噪,最終得到一個真實分佈的圖片$x_0$。

但是我們實際上沒辦法知道$p({x}{t-1} | {x}_t)$。因為它需要知道所有可能影象的分佈來計算這個條件概率。因此,我們需要藉助神經網路來近似(學習)這個條件概率分佈。 也就是$p\theta ({x}_{t-1} | {x}_t)$,其中, $\theta$是神經網路的引數,需要使用梯度下降更新。

所以現在我們需要一個神經網路來表示逆向過程的(條件)概率分佈。如果我們假設這個反向過程也是高斯分佈,那麼回想一下,任何高斯分佈都是由兩個引數定義的:

  • 一個均值$\mu_\theta$;
  • 一個方差$\Sigma_\theta$。

所以我們可以把這個過程引數化為

$$ p_\theta (\mathbf{x}{t-1} | \mathbf{x}_t) = \mathcal{N}(\mathbf{x}{t-1}; \mu_\theta(\mathbf{x}{t},t), \Sigma\theta (\mathbf{x}_{t},t)) $$

其中均值和方差也取決於噪聲水平$t$。

從上邊我們可以知道,逆向過程我們需要一個神經網路來學習(表示)高斯分佈的均值和方差。

帶DDPM中作者固定方差,只讓神經網路學習條件概率分佈的均值。

First, we set $\Sigma_\theta ( \mathbf{x}_t, t) = \sigma^2_t \mathbf{I}$ to untrained time dependent constants. Experimentally, both $\sigma^2_t = \beta_t$ and $\sigma^2_t = \tilde{\beta}_t$ (see paper) had similar results.

之後再Improved diffusion models這篇文章中進行了改進,神經網路既需要學習均值也要學習方差。

通過重新引數化平均值定義目標函式

為了推匯出一個目標函式來學習逆向過程的均值,作者觀察到$q$和$p_\theta$可以看做是一個VAE模型 (Kingma et al., 2013).

因此,變分下界(ELBO)可以用來最小化關於ground truth$x_0$的負對數似然。

這個過程的ELBO是每個時間步$t$的損失總,$L=L_0+L_1+…+L_𝑇$。

通過構建正向𝑞過程和反向過程,損失的每一項(除了$L_0$)是兩個高斯分佈之間的KL散度,可以明確地寫為關於平均值的$L_2$損失!

因為高斯分佈的特性,我們不需要在正向$q$過程中迭代$t$步就可以獲得$x_t$的結果:

$$ q({x}_t | {x}_0) = \cal{N}({x}_t; \sqrt{\bar{\alpha}_t} {x}_0, (1- \bar{\alpha}_t) \mathbf{I}) $$

其中$\alpha_t := 1 - \beta_t$ and $\bar{\alpha}t := \Pi{s=1}^{t} \alpha_s$。

這是一個很優秀的屬性。這意味著我們可以對高斯噪聲進行取樣並適當縮放直接將其新增到$x_0$中就可以得到$x_t$。請注意,$\bar{\alpha}_t$是方差表$\beta_t$的函式,因此也是已知的,我們可以對其預先計算。這樣可以讓我們在訓練期間優化損失函式$L$的隨機項(換句話說,在訓練期間隨機取樣$t$就可以優化$L_t$)。

這個屬性的另一個優美之處this excellent blog post) 在於對均值進行引數重整化,使神經網路學習(預測)新增的噪聲(通過網路$\epsilon_\theta(x_t,t)$),在KL項中構成損失的噪聲級別$t$。這意味著我們的神經網路變成了噪聲預測器,而不是直接去預測均值了。均值的計算方法如下:

$$ {\mu}\theta({x}_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( {x}_t - \frac{\beta_t}{\sqrt{1- \bar{\alpha}_t}} {\epsilon}\theta({x}_t, t) \right)$$

最後的目標函式$L_t$ 長這樣,給定隨機的時間步 $t$ 使${\epsilon} \sim \mathcal{N}({0}, {I})$ ):

$$ \| {\epsilon} - {\epsilon}\theta({x}_t, t) \|^2 = \| {\epsilon} - {\epsilon}\theta( \sqrt{\bar{\alpha}_t} {x}_0 + \sqrt{(1- \bar{\alpha}_t) } {\epsilon}, t) \|^2.$$

其中$x_0$是初始影象,我們看到噪聲$t$樣本由固定的前向過程給出。$\epsilon$是在時間步長$t$取樣的純噪聲,$\epsilon_\theta(x_t,t)$是我們的神經網路。神經網路的優化使用一個簡單的均方誤差(MSE)之間的真實和預測高斯噪聲。 訓練演算法如下

image.png

  1. 從位置且複雜的真實資料分佈$q(x_0)$中隨機取樣$x_0$,
  2. 我們在1和$T$之間均勻採不同時間步的噪聲,
  3. 我們從高斯分佈取樣一些噪聲,並在$𝑡$時間步上使用前邊定義的優良屬性來破壞輸入分佈,
  4. 神經網路根據損壞的影象$x_t$進行訓練,目的是預測施加在圖片上的噪聲,也就是基於已知方差表$\beta_t$作用在$x_0$上的噪聲

所有這些都是在批量資料上完成的,和使用隨機梯度下降法來優化神經網路一樣。