對抗生成網路GAN系列——AnoGAN原理及缺陷檢測實戰
theme: fancy
本文為稀土掘金技術社群首發簽約文章,14天內禁止轉載,14天后未獲授權禁止轉載,侵權必究!
🍊作者簡介:禿頭小蘇,致力於用最通俗的語言描述問題
🍊往期回顧:對抗生成網路GAN系列——GAN原理及手寫數字生成小案例 對抗生成網路GAN系列——DCGAN簡介及人臉影象生成案例 對抗生成網路GAN系列——CycleGAN簡介及圖片春冬變換案例
🍊近期目標:寫好專欄的每一篇文章
🍊支援小蘇:點贊👍🏼、收藏⭐、留言📩
對抗生成網路GAN系列——AnoGAN原理及缺陷檢測實戰
寫在前面
隨著深度學習的發展,已經有很多學者將深度學習應用到物體瑕疵檢測中,如列車鋼軌的缺陷檢測、醫學影像中各種疾病的檢測。但是瑕疵檢測任務幾乎都存在一個共同的難題——缺陷資料太少了。我們使用這些稀少的缺陷資料很難利用深度學習訓練一個理想的模型,往往都需要進行資料擴充,即通過某些手段增加我們的缺陷資料。 【資料擴充大家感興趣自己去了解下,GAN網路也是實現資料擴充的主流手段】 上面說到的方法是基於缺陷資料來訓練的,是有監督的學習,學者們在漫長的研究中,考慮能不能使用一種無監督的方法來實現缺陷檢測呢?於是啊,AnoGAN就橫空出世了,它不需要缺陷資料進行訓練,而僅使用正常資料訓練模型,關於AnoGAN的細節後文詳細介紹。
關於GAN網路,我已經介紹了幾篇,如下:
- [1]對抗生成網路GAN系列——GAN原理及手寫數字生成小案例 🍁🍁🍁
- [2]對抗生成網路GAN系列——DCGAN簡介及人臉影象生成案例🍁🍁🍁
- [3]對抗生成網路GAN系列——CycleGAN原理🍁🍁🍁
在閱讀本文之前建議大家對GAN有一定的瞭解,可以參考[1]和[2],關於[3]感興趣的可以看看,本篇文章用不到[3]相關知識。
準備好了嘛,我們開始發車了喔。🚖🚖🚖
AnoGAN 原理詳解✨✨✨
首先我們來看看AnoGAN的全稱,即Anomaly Detection with Generative Adversarial Networks
,中文是指使用生成對抗網路實現異常檢測。這篇論文解決的是醫學影像中疾病的檢測,由於對醫學相關內容不瞭解,本文將完全將該演算法從論文中剝離,只介紹演算法原理,而不結合論文進行講述。想要了解論文詳情的可以點選☞☞☞檢視。
接下來就隨我一起來看看AnoGAN的原理。其實AnoGAN的原理是很簡單的,但是我看網上的資料總是說的摸稜兩可,我認為主要原因有兩點:其一是沒有把AnoGAN的原理分步來敘述,其二是有專家視角,它們認為我們都應該明白,但這對於新手來說理解也確實是有一定難度的。
在介紹AnoGAN的具體原理時,我先來談談AnoGAN的出發點,這非常重要,大家好好感受。我們知道,DCGAN是將一個噪聲或者說一個潛在變數對映成一張圖片,在我們訓練DCGAN時,都是使用某一種資料進行的,如[2]中使用的資料都是人臉,那麼這些資料都是正常資料,我們從一個潛在變數經DCGAN後生成的圖片應該也都是正常影象。AnoGAN的想法就是我能否將一張圖片M對映成某個潛在變數呢,這其實是較難做到的。但是我們可以在某個空間不斷的查詢一個潛在變數,使得這個潛在變數生成的圖片與圖片M儘可能接近。這就是AnoGAN的出發點,大家可能還不明白這麼做的意義,下文為大家詳細介紹。☘☘☘
AnoGAN其實是分兩個階段進行的,首先是訓練階段,然後是測試階段,我們一點點來看:
-
訓練階段
訓練階段僅使用正常的資料訓練對抗生成網路。如我們使用手寫數字中的數字8作為本階段的資料進行訓練,那麼8就是正常資料。訓練結束後我們輸入一個向量z,生成網路會將z變成8。不知道大家有沒有發現其實這階段就是[2]中的DCGAN呢? 【注意:訓練階段已經訓練好GAN網路,後面的測試階段GAN網路的權重是不在變換的】
-
測試階段
在訓練階段我們已經訓練好了一個GAN網路,在這一階段我們就是要利用訓練好的網路來進行缺陷檢測。如現在我們有一個數據6,此為缺陷資料 【訓練時使用8進行訓練,這裡的6即為缺陷資料】 。現在我們要做的就是搜尋一個潛在變數並讓其生成的圖片與圖片6儘可能接近,具體實現如下:首先我們會定義一個潛在變數z,然後經過剛剛訓練的好的生成網路,得到假影象G(z),接著G(z)和缺陷資料6計算損失,這時候損失往往會比較大,我們不斷的更新z值,會使損失不斷的減少,在程式中我們可以設定更新z的次數,如更新500次後停止,此時我們認為將如今的潛在變數z送入生成網路得到的假影象已經和圖片6非常像了,於是我們將z再次送入生成網路,得到G(z)。【注:由於潛在變數z送入的網路是生成圖片8的,儘管通過搜尋使G(z)和6儘可能相像,但還是存在一定差距,即它們的損失較大】 最後我們就可以計算G(z)和圖片6的損失,記為loss1,並將這個損失作為判斷是否有缺陷的重要依據。怎麼作為判斷是否有缺陷的重要依據呢?我再舉個例子大家就明白了,現在在測試階段我們傳入的不是缺陷資料,而是正常的資料8,此時應用相同的方法搜尋潛在變數z,然後將最終的z送入生成網路,得到G(z),最後計算G(z)和圖片8的損失。 【注:由於潛在變數z送入的網路是生成圖片8的,所以最後生成的G(z)可以和資料8很像,即它們的損失較小】 通過以上分析, 我們可以發現當我們在測試階段傳入缺陷圖片時最終的損失大,傳入正常圖片時的損失小,這時候我們就可以設定一個合適的閾值來判斷影象是否有缺陷了。🥂🥂🥂 這一段是整個AnoGAN的重點,大家多思考思考,相信你可以理解。我也畫了一個此過程的流程圖,大家可以參考一下,如下:
讀了上文,是不是對AnoGAN大致過程有了一定了解了呢!我覺得大家訓練階段肯定是沒問題的啦,就是一個DCGAN網路,不清楚這個的話建議閱讀[2]瞭解DCGAN網路。測試階段的難點就在於我們如何定義損失函式來更新z值,我們直接來看論文中此部分的損失,主要分為兩部分,分別是Residual Loss和Discrimination Loss,它們定義如下:
- Residual Loss
${\rm{R}}(z) = \sum {|x - G(z)|}$
上式z表示潛在變數,G(z)表示生成的假影象,x表示輸入的測試圖片。上式表示生成的假影象和輸入圖片之間的差距。如果生成的圖片越接近x,則R(z)越小。
- Discrimination Loss
$D(z) = \sum {|f(x) - f(G(z))|}$
上式z表示潛在變數,G(z)表示生成的假影象,x表示輸入的測試圖片。f()表示將通過判別器,然後取判別器某一層的輸出結果。 【注:這裡使用的並非判別器的最終輸出,而是判別器某層的輸出,關於這一點,會在程式碼講解時介紹】 這裡可以把判別器當作一個特徵提取網路,我們將生成的假圖片和測試圖片都輸入判別器,看它們提取到特徵的差異。同樣,如果生成的圖片越接近x,則D(z)越小。
求得R(z)和D(z)後,我們定義它們的線性組合作為最終的損失,如下:
$Loss(z)=(1-\lambda)R(z)+\lambda D(z)$
通常,我們取$\lambda =0.1$
到這裡,AnoGAN的理論部分都介紹完了喔!!!不知道你理解了多少呢?如果覺得有些地方理解還差點兒意思的話,就來看看下面的程式碼吧,這回對你理解AnoGAN非常有幫助。🌱🌱🌱
AnoGAN程式碼實戰
如果大家和我一樣找過AnoGAN程式碼的話,可能就會和我有一樣的感受,那就是太亂了。怎麼說呢,我認為從原理上來說,應該很好實現AnoGAN,但是我看Github上的程式碼寫的挺複雜,不是很好理解,有的甚至起著AnoGAN的名字,實現的卻是一個簡單的DCGAN網路,著實讓人有些無語。於是我打算按照自己的思路來實現一個AnoGAN,奈何卻出現了各種各樣的Bug,正當我心灰意冷時,看到了一篇外文的部落格,寫的非常對我的胃口,於是按照它的思路實現了AnoGAN。這裡我還是想感概一下,我發現很多外文的部落格確實寫的非常漂亮,我想這是值得我們學習的地方!!!🌼🌼🌼
程式碼下載地址✨✨✨
本次我將原始碼上傳到我的Github了,大家可以閱讀README檔案瞭解程式碼的使用,Github地址如下:
我認為你閱讀README檔案後已經對這個專案的結構有所瞭解,我在下文也會幫大家分析分析原始碼,但更多的時間大家應該自己動手去親自除錯,這樣你會有不一樣的收穫。🌾🌾🌾
資料讀取✨✨✨
本次使用的資料為mnist手寫數字資料集,我們下載的是.csv格式的資料,這種格式方便讀取。讀取資料程式碼如下:
python
## 讀取訓練集資料 (60000,785)
train = pd.read_csv(".\data\mnist_train.csv",dtype = np.float32)
## 讀取測試集資料 (10000,785)
test = pd.read_csv(".\data\mnist_test.csv",dtype = np.float32)
我們可以來看一下mnist資料集的格式是怎樣的,先來看看train中的內容,如下:
train的shape為(60000,785),其表示訓練集中共有60000個數據,即60000張手寫數字的圖片,每個資料都有785個值。我們來分析一下這785個數值的含義,第一個數值為標籤label,表示其表示哪個手寫數字,後784個數值為對應數字每個畫素的值,手寫數字圖片大小為28×28,故一共有784個畫素值。
解釋完訓練集資料的含義,那測試集也是一樣的啦,只不過資料較少,只有10000條資料,test的內容如下:
大家需要注意的是,上述的訓練集和測試集中的資料我們今天並不會全部用到。我們取訓練集中的前400個標籤為7或8的資料作為AnoGAN的訓練集,即7、8都為正常資料。取測試集前600個標籤為2、7、8作為測試資料,即測試集中有正常資料(7、8)和異常資料(2),相關程式碼如下:
python
# 查詢訓練資料中標籤為7、8的資料,並取前400個
train = train.query("label in [7.0, 8.0]").head(400)
# 查詢訓練資料中標籤為7、8的資料,並取前400個
test = test.query("label in [2.0, 7.0, 8.0]").head(600)
可以看看此時的train和test的結果:
在AnoGAN中,我們是無監督的學習,因此是不需要標籤的,通過以下程式碼去除train和test中的標籤:
python
# 取除標籤後的784列資料
train = train.iloc[:,1:].values.astype('float32')
test = test.iloc[:,1:].values.astype('float32')
去除標籤後train和test的結果如下:
可以看出,此時train和test中已經沒有了label類,它們的第二個維度也從785變成了784。
最後,我們將train和test reshape成圖片的格式,即28×28,程式碼如下:
python
# train:(400,784)-->(400,28,28)
# test:(600,784)-->(600,28,28)
train = train.reshape(train.shape[0], 28, 28)
test = test.reshape(test.shape[0], 28, 28)
此時,train和test的維度發生變換,如下圖所示:
至此,我們的資料讀取部分就為大家介紹完了,是不是發現挺簡單的呢,加油吧!!!🥂🥂🥂
模型搭建
模型搭建真滴很簡單!!!大家之間看程式碼吧。🌻🌻🌻
生成模型搭建
python
"""定義生成器網路結構"""
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1, activation=nn.ReLU(inplace=True), bn=True):
seq = []
seq += [nn.ConvTranspose2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)]
if bn is True:
seq += [nn.BatchNorm2d(out_channel)]
seq += [activation]
return nn.Sequential(*seq)
seq = []
seq += [CBA(20, 64*8, stride=1, padding=0)]
seq += [CBA(64*8, 64*4)]
seq += [CBA(64*4, 64*2)]
seq += [CBA(64*2, 64)]
seq += [CBA(64, 1, activation=nn.Tanh(), bn=False)]
self.generator_network = nn.Sequential(*seq)
def forward(self, z):
out = self.generator_network(z)
return out
為了幫助大家理解,我繪製 了生成網路的結構圖,如下:
判別模型搭建
python
"""定義判別器網路結構"""
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1, activation=nn.LeakyReLU(0.1, inplace=True)):
seq = []
seq += [nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)]
seq += [nn.BatchNorm2d(out_channel)]
seq += [activation]
return nn.Sequential(*seq)
seq = []
seq += [CBA(1, 64)]
seq += [CBA(64, 64*2)]
seq += [CBA(64*2, 64*4)]
seq += [CBA(64*4, 64*8)]
self.feature_network = nn.Sequential(*seq)
self.critic_network = nn.Conv2d(64*8, 1, kernel_size=4, stride=1)
def forward(self, x):
out = self.feature_network(x)
feature = out
feature = feature.view(feature.size(0), -1)
out = self.critic_network(out)
return out, feature
同樣,為了方便大家理解,我也繪製了判別網路的結構圖,如下:
這裡大家需要稍稍注意一下,判別網路有兩個輸出,一個是最終的輸出,還有一個是第四個CBA BLOCK提取到的特徵,這個在理論部分介紹損失函式時有提及。
模型訓練
資料集載入
python
class image_data_set(Dataset):
def __init__(self, data):
self.images = data[:,:,:,None]
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize(64, interpolation=InterpolationMode.BICUBIC),
transforms.Normalize((0.1307,), (0.3081,))
])
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
return self.transform(self.images[idx])
# 載入訓練資料
train_set = image_data_set(train)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
這部分不難,但我提醒大家注意一下這句:transforms.Resize(64, interpolation=InterpolationMode.BICUBIC)
,即我們採用插值演算法將原來2828大小的圖片上取樣成了6464大小。 【感興趣的這裡也可以不對其進行上取樣,這樣的話大家需要修改一下上節的模型,可以試試效果喔】
載入模型、定義優化器、損失函式等引數
python
# 指定裝置
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
# batch_size預設128
batch_size = args.batch_size
# 載入模型
G = Generator().to(device)
D = Discriminator().to(device)
# 訓練模式
G.train()
D.train()
# 設定優化器
optimizerG = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.0, 0.9))
optimizerD = torch.optim.Adam(D.parameters(), lr=0.0004, betas=(0.0, 0.9))
# 定義損失函式
criterion = nn.BCEWithLogitsLoss(reduction='mean')
訓練GAN網路
python
"""
訓練
"""
# 開始訓練
for epoch in range(args.epochs):
# 定義初始損失
log_g_loss, log_d_loss = 0.0, 0.0
for images in train_loader:
images = images.to(device)
## 訓練判別器 Discriminator
# 定義真標籤(全1)和假標籤(全0) 維度:(batch_size)
label_real = torch.full((images.size(0),), 1.0).to(device)
label_fake = torch.full((images.size(0),), 0.0).to(device)
# 定義潛在變數z 維度:(batch_size,20,1,1)
z = torch.randn(images.size(0), 20).to(device).view(images.size(0), 20, 1, 1).to(device)
# 潛在變數喂入生成網路--->fake_images:(batch_size,1,64,64)
fake_images = G(z)
# 真影象和假影象送入判別網路,得到d_out_real、d_out_fake 維度:都為(batch_size,1,1,1)
d_out_real, _ = D(images)
d_out_fake, _ = D(fake_images)
# 損失計算
d_loss_real = criterion(d_out_real.view(-1), label_real)
d_loss_fake = criterion(d_out_fake.view(-1), label_fake)
d_loss = d_loss_real + d_loss_fake
# 誤差反向傳播,更新損失
optimizerD.zero_grad()
d_loss.backward()
optimizerD.step()
## 訓練生成器 Generator
# 定義潛在變數z 維度:(batch_size,20,1,1)
z = torch.randn(images.size(0), 20).to(device).view(images.size(0), 20, 1, 1).to(device)
fake_images = G(z)
# 假影象喂入判別器,得到d_out_fake 維度:(batch_size,1,1,1)
d_out_fake, _ = D(fake_images)
# 損失計算
g_loss = criterion(d_out_fake.view(-1), label_real)
# 誤差反向傳播,更新損失
optimizerG.zero_grad()
g_loss.backward()
optimizerG.step()
## 累計一個epoch的損失,判別器損失和生成器損失分別存放到log_d_loss、log_g_loss中
log_d_loss += d_loss.item()
log_g_loss += g_loss.item()
## 列印損失
print(f'epoch {epoch}, D_Loss:{log_d_loss / 128:.4f}, G_Loss:{log_g_loss / 128:.4f}')
## 展示生成器儲存的圖片,存放在result資料夾下的G_out.jpg
z = torch.randn(8, 20).to(device).view(8, 20, 1, 1).to(device)
fake_images = G(z)
torchvision.utils.save_image(fake_images,f"result\G_out.jpg")
這部分就是訓練一個DCGAN網路,到目前為止其實也都可以認為是DCGAN的內容。我們可以來看一下輸出的G_out.jpg
圖片:
這裡我們可以看到訓練是有了效果的,但會發現不是特別好。我分析有兩點原因,其一是我們的模型不好,且GAN本身就容易出現模式崩潰的問題;其二是我們的資料選取的少,在資料讀取時訓練集我們只取了前400個數據,但實際上我們一共可以取12116個,大家可以嘗試增加資料,我想資料多了後效果肯定比這個好,大家快去試試吧!!!🍉🍉🍉
缺陷檢測✨✨✨
這部分才是AnoGAN的重點,首先我們先定義損失的計算,如下:
python
## 定義缺陷計算的得分
def anomaly_score(input_image, fake_image, D):
# Residual loss 計算
residual_loss = torch.sum(torch.abs(input_image - fake_image), (1, 2, 3))
# Discrimination loss 計算
_, real_feature = D(input_image)
_, fake_feature = D(fake_image)
discrimination_loss = torch.sum(torch.abs(real_feature - fake_feature), (1))
# 結合Residual loss和Discrimination loss計算每張影象的損失
total_loss_by_image = 0.9 * residual_loss + 0.1 * discrimination_loss
# 計算總損失,即將一個batch的損失相加
total_loss = total_loss_by_image.sum()
return total_loss, total_loss_by_image, residual_loss
大家可以對比一下理論部分損失函式的介紹,看看是不是一樣的呢。
接著我們就需要不斷的搜尋潛在變數z了,使其與輸入圖片儘可能接近,程式碼如下:
python
# 載入測試資料
test_set = image_data_set(test)
test_loader = DataLoader(test_set, batch_size=5, shuffle=False)
input_images = next(iter(test_loader)).to(device)
# 定義潛在變數z 維度:(5,20,1,1)
z = torch.randn(5, 20).to(device).view(5, 20, 1, 1)
# z的requires_grad引數設定成Ture,讓z可以更新
z.requires_grad = True
# 定義優化器
z_optimizer = torch.optim.Adam([z], lr=1e-3)
# 搜尋z
for epoch in range(5000):
fake_images = G(z)
loss, _, _ = anomaly_score(input_images, fake_images, D)
z_optimizer.zero_grad()
loss.backward()
z_optimizer.step()
if epoch % 1000 == 0:
print(f'epoch: {epoch}, loss: {loss:.0f}')
執行完上述程式碼後,我們得到了一個較理想的潛在變數,這時候再用z來生成圖片,並基於生成圖片和輸入圖片來計算損失,同時,我們也儲存了輸入圖片和生成圖片,並列印了它們之前的損失,相關程式碼如下:
python
fake_images = G(z)
_, total_loss_by_image, _ = anomaly_score(input_images, fake_images, D)
print(total_loss_by_image.cpu().detach().numpy())
torchvision.utils.save_image(input_images, f"result/Nomal.jpg")
torchvision.utils.save_image(fake_images, f"result/ANomal.jpg")
我們可以來看看最後的結果哦,如下:
可以看到,當輸入影象為2時(此為缺陷),生成的影象也是8,它們的損失最高為464040.44。這時候如果我們設定一個閾值為430000,高於這個閾值的即為異常圖片,低於這個閾值的即為正常圖片,那麼我們是不是就可以通過AnoGAN來實現缺陷的檢測了呢!!!🍒🍒🍒
總結
到這裡,AnoGAN的所有內容就介紹完了,大家好好感受感受它的思想,其實是很簡單的,但是又非常巧妙。最後我不知道大家有沒有發現AnoGAN一個非常明顯的缺陷,那就是我們每次在判斷異常時要不斷的搜尋潛在變數z,這是非常耗時的。而很多工對時間的要求還是很高的,所以AnoGAN還有許多可以改進的地方,後續博文我會帶大家繼續學習GAN網路在缺陷檢測中的應用,我們下期見。🖐🏽🖐🏽🖐🏽
參考文獻
AnoGAN論文🍁🍁🍁
深度學習論文筆記(異常檢測) 🍁🍁🍁
如若文章對你有所幫助,那就🛴🛴🛴
咻咻咻咻~~duang\~~點個讚唄
- 兔年到了,一起來寫個春聯吧
- CV攻城獅入門VIT(vision transformer)之旅——VIT程式碼實戰篇
- 對抗生成網路GAN系列——GANomaly原理及原始碼解析
- 對抗生成網路GAN系列——WGAN原理及實戰演練
- CV攻城獅入門VIT(vision transformer)之旅——近年超火的Transformer你再不瞭解就晚了!
- 對抗生成網路GAN系列——DCGAN簡介及人臉影象生成案例
- 對抗生成網路GAN系列——CycleGAN簡介及圖片春冬變換案例
- 對抗生成網路GAN系列——AnoGAN原理及缺陷檢測實戰
- 目標檢測系列——Faster R-CNN原理詳解
- 目標檢測系列——Fast R-CNN原理詳解
- 目標檢測系列——開山之作RCNN原理詳解
- 【古月21講】ROS入門系列(4)——引數使用與程式設計方法、座標管理系統、tf座標系廣播與監聽的程式設計實現、launch啟動檔案的使用方法
- 使用kitti資料集實現自動駕駛——繪製出所有物體的行駛軌跡
- 使用kitti資料集實現自動駕駛——釋出照片、點雲、IMU、GPS、顯示2D和3D偵測框
- 基於pytorch搭建ResNet神經網路用於花類識別
- 基於pytorch搭建GoogleNet神經網路用於花類識別
- 基於pytorch搭建VGGNet神經網路用於花類識別
- UWB原理分析
- 論文閱讀:RRPN:RADAR REGION PROPOSAL NETWORK FOR OBJECT DETECTION IN AUTONOMOUS
- 凸優化理論基礎2——凸集和錐