Keras深度學習——生成對抗網路
theme: hydrogen
持續創作,加速成長!這是我參與「掘金日新計劃 · 6 月更文挑戰」的第27天,點選檢視活動詳情
前言
生成對抗網路 (Generative Adversarial Networks
, GAN
) 使用神經網路生成與原始影象集非常相似的新影象,它在影象生成中應用廣泛,且 GAN
的相關研究正在迅速發展,以偽造生成與真實影象難以區分的逼真影象。在本節中,我們將學習 GAN
網路的原理並使用 Keras
實現 GAN
。
生成對抗網路詳解
GAN
包含兩個網路:生成器和鑑別器。生成器的目標是生成逼真的影象騙過鑑別器,鑑別器的目標是確定輸入影象是真實影象還是生成器生成的偽造影象。
假設 GAN
用於生成人臉影象,鑑別器試圖將圖片分類為真實人臉影象或者偽造的虛假人臉影象,一旦我們訓練完成的鑑別器能夠將正確分類真實人臉影象和虛假人臉影象,如果我們向鑑別器輸入新的人臉圖片,它能夠將輸入圖片分類為真實人臉影象和虛假人臉影象。生成器的任務是生成看起來與原始影象集非常相似的人臉影象,以至於鑑別器會誤以為所生成的影象來自原始資料集。
接下來,我們詳細介紹 GAN
生成影象的網路策略:
- 使用生成器生成偽造影象,生成器在最初只能生成噪聲影象,噪聲影象是通過將一組噪聲值通過權重隨機的神經網路得到的影象
- 將生成的影象與原始影象串聯起來,鑑別器預測每個影象是偽造影象還是真實影象,對鑑別器進行訓練:
- 在迭代中訓練鑑別器權重
- 鑑別器的損失是影象的預測值和實際值(標籤)的二進位制交叉熵
- 生成的偽造影象的實際值(標籤)為 0
,原始資料集中真實影象的實際值(標籤)為 1
- 對鑑別器進行一次迭代訓練後,就可以訓練生成器利用輸入噪聲生成偽造影象,使其看起來更接近真實影象,從而使生成影象有可能欺騙鑑別器:
- 輸入噪聲通過生成器傳遞,通過多個隱藏層後,生成器最後輸出偽造影象
- 將生成器生成的影象輸入到鑑別器中,需要注意的是,鑑別器權重在此迭代訓練中被凍結,因此在此迭代中不對其進行訓練
- 在此訓練過程中,因為生成器的目標是欺騙鑑別器,因此,假設生成的虛假影象實際值(標籤)為 1
- 生成器的損失是鑑別器對輸入影象的預測值和實際值 (1
) 的二進位制交叉熵:
- 此步驟中凍結鑑別器權重,凍結鑑別器可確保生成器從鑑別器提供的輸出反饋中進行學習
- 重複以上過程,直到生成逼真的影象
利用生成對抗網路生成手寫數字影象
在本節中,我們採用 Keras
實現 GAN
,並使用 MNIST
資料集訓練 GAN
生成手寫數字影象。
首先,匯入相關庫,並定義超引數:
```python
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Reshape, Flatten
from keras.models import Sequential
from keras.optimizers import Adam
import matplotlib.pyplot as plt
from keras.layers import BatchNormalization, LeakyReLU
shape = (28, 28, 1)
epochs = 5000
batch_size = 64
save_interval = 100
接下來,定義生成器,對於生成器模型,其採用形狀為 `100` 維的噪聲向量,通過數個全連線層後生成 `28×28×1=1024` 的向量,最後將其整形為形狀為 `(28, 28, 1)` 的影象,在模型中使用 `LeakyReLU` 啟用函式。:
python
def generator():
model = Sequential()
model.add(Dense(256, input_shape=(100,)))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(28281, activation='tanh'))
model.add(Reshape(shape))
return model
生成器的簡要資訊輸出如下:
shell
Model: "sequential"
Layer (type) Output Shape Param #
dense (Dense) (None, 256) 25856
leaky_re_lu (LeakyReLU) (None, 256) 0
batch_normalization (BatchNo (None, 256) 1024
dense_1 (Dense) (None, 512) 131584
leaky_re_lu_1 (LeakyReLU) (None, 512) 0
batch_normalization_1 (Batch (None, 512) 2048
dense_2 (Dense) (None, 1024) 525312
leaky_re_lu_2 (LeakyReLU) (None, 1024) 0
batch_normalization_2 (Batch (None, 1024) 4096
dense_3 (Dense) (None, 784) 803600
reshape (Reshape) (None, 28, 28, 1) 0
Total params: 1,493,520 Trainable params: 1,489,936 Non-trainable params: 3,584
接下來,我們將構建鑑別器模型,該模型將形狀為 `(28, 28, 1)` 的輸入影象,併產生輸出 `1` 或 `0`,用於表示輸入影象是原始真實影象還是生成的偽造影象:
python
def discriminator():
model = Sequential()
model.add(Flatten(input_shape=shape))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='sigmoid'))
return model
鑑別器模型的簡要結構資訊輸出如下:
shell
Model: "sequential_1"
Layer (type) Output Shape Param #
flatten (Flatten) (None, 784) 0
dense_4 (Dense) (None, 1024) 803840
leaky_re_lu_3 (LeakyReLU) (None, 1024) 0
dense_5 (Dense) (None, 256) 262400
leaky_re_lu_4 (LeakyReLU) (None, 256) 0
dense_6 (Dense) (None, 1) 257
Total params: 1,066,497 Trainable params: 1,066,497 Non-trainable params: 0
編譯生成器和鑑別器模型:
python
generator = generator()
generator.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5, decay=8e-8))
discriminator = discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5, decay=8e-8), metrics=['acc'])
組合生成器與鑑別器,定義 `GAN` 模型,該模型用於訓練生成器的權重,同時凍結鑑別器的權重。`GAN` 模型將隨機噪聲作為輸入,並使用生成器網路將該噪聲轉換為形狀為 `(28, 28, 1)` 的影象,然後模型預測生成的影象是真實影象還是偽造影象:
python
def gan(discriminator, generator):
discriminator.trainable = False
model = Sequential()
model.add(generator)
model.add(discriminator)
return model
gan = gan(discriminator, generator)
gan.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5, decay=8e-8))
定義函式用於繪製生成的影象:
python
def plot_images(samples=16, step=0):
noise = np.random.normal(0, 1, (samples, 100))
images = generator.predict(noise)
plt.figure(figsize=(10, 10))
for i in range(images.shape[0]):
plt.subplot(4, 4, i + 1)
image = images[i, :, :, :]
image = np.reshape(image, [28, 28])
plt.imshow(image, cmap='gray')
plt.axis('off')
plt.tight_layout()
plt.show()
載入 `MNIST` 資料集,並對資料集進行預處理:
python
(x_train, ), (, _) = mnist.load_data()
x_train = (x_train.astype(np.float32) - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=3)
因為 `GAN` 模型基於給定的影象集 `x_train` 生成新影象,因此我們不需要輸出標籤。
接下來,通過在多個 `epochs` 內訓練 `GAN` 來優化網路權重。
獲取真實影象 `legit_images` 並利用噪聲資料生成偽造影象 `synthetic_images`,使用噪聲資料 `gen_noise` 作為輸入,嘗試生成真實影象:
python
disc_loss = []
gen_loss = []
for cnt in range(epochs):
random_index = np.random.randint(0, len(x_train) - batch_size / 2)
legit_images = x_train[random_index: random_index + batch_size // 2].reshape(batch_size // 2, 28, 28, 1)
gen_noise = np.random.normal(-1, 1, (batch_size // 2, 100))/2
synthetic_images = generator.predict(gen_noise)
使用 `train_on_batch` 方法訓練鑑別器,`train_on_batch` 用於使用單個批資料對模型執行一次梯度更新,在輸出中,實際影象的值為 `1`,偽造影象的值為 `0`:
python
x_combined_batch = np.concatenate((legit_images, synthetic_images))
y_combined_batch = np.concatenate((np.ones((batch_size // 2, 1)), np.zeros((batch_size // 2, 1))))
d_loss = discriminator.train_on_batch(x_combined_batch, y_combined_batch)
接下來,我們準備用於訓練生成器的資料,隨機噪聲作為輸入資料 `noise`,而 `y_mislabeled` 是用於訓練生成器的輸出,需要注意的是,這裡的輸出與訓練鑑別器時的輸出完全相反,即使用 `1` 作為偽造影象的值:
python
noise = np.random.normal(-1, 1, (batch_size, 100))/2
y_mislabled = np.ones((batch_size, 1))
接下來,我們訓練 `GAN` 模型,其中鑑別器權重被凍結,而生成器的權重會得到更新以最小化損失,生成器的任務是生成可欺騙鑑別器的影象,即令鑑別器輸出值 `1`:
python
g_loss = stacked_generator_discriminator.train_on_batch(noise, y_mislabled)
然後,我們記錄各個 `epoch` 內的生成器損失和鑑別器損失,並按照指定間隔檢視生成器生成影象:
python
g_loss = gan.train_on_batch(noise, y_mislabled)
disc_loss.append(d_loss[0])
gen_loss.append(g_loss)
print('epoch: {}, [Discriminator: {}], [Generator: {}]'.format(cnt, d_loss[0], g_loss))
if cnt % 100 == 0:
plot_images(step=cnt)
```
在人眼看來,生成的影象仍然並不真實,因此模型仍具有很大的改進空間,我們將在之後的學習中介紹能夠生成更加逼真影象的 GAN
架構。
最後,繪製 GAN
訓練過程中的損失變化情況,隨著訓練 epoch
的增加,鑑別器損失和生成器損失的變化如下:
python
epochs = range(1, epochs+1)
plt.plot(epochs, disc_loss, 'bo', label='Discriminator loss')
plt.plot(epochs, gen_loss, 'r', label='Generator loss')
plt.title('Generator and Discriminator loss values')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()
- OpenCV使用顏色進行膚色檢測
- Keras深度學習——構建電影推薦系統
- PyTorch張量操作詳解
- OpenCV直方圖的比較
- Python 常用字串操作
- 使用 dlib 進行人臉識別
- OpenCV 人臉檢測詳解(僅需2行程式碼學會人臉檢測)
- Keras深度學習——使用fastText構建單詞向量
- Keras深度學習——使用skip-gram和CBOW模型構建單詞向量
- Keras深度學習——從零開始構建單詞向量
- Keras深度學習——生成對抗網路
- Keras深度學習——建立自定義目標檢測資料集
- PyTorch強化學習——基於值迭代的強化學習演算法
- PyTorch強化學習——模擬FrozenLake環境
- PyTorch強化學習——策略評估
- PyTorch強化學習——馬爾科夫決策過程
- Keras深度學習——DeepDream演算法生成影象
- Keras深度學習——使用對抗攻擊生成可欺騙神經網路的影象
- PyTorch強化學習——策略梯度演算法
- Keras深度學習——交通標誌識別