對抗生成網絡GAN系列——AnoGAN原理及缺陷檢測實戰
theme: fancy
本文為稀土掘金技術社區首發簽約文章,14天內禁止轉載,14天后未獲授權禁止轉載,侵權必究!
🍊作者簡介:禿頭小蘇,致力於用最通俗的語言描述問題
🍊往期回顧:對抗生成網絡GAN系列——GAN原理及手寫數字生成小案例 對抗生成網絡GAN系列——DCGAN簡介及人臉圖像生成案例 對抗生成網絡GAN系列——CycleGAN簡介及圖片春冬變換案例
🍊近期目標:寫好專欄的每一篇文章
🍊支持小蘇:點贊👍🏼、收藏⭐、留言📩
對抗生成網絡GAN系列——AnoGAN原理及缺陷檢測實戰
寫在前面
隨着深度學習的發展,已經有很多學者將深度學習應用到物體瑕疵檢測中,如列車鋼軌的缺陷檢測、醫學影像中各種疾病的檢測。但是瑕疵檢測任務幾乎都存在一個共同的難題——缺陷數據太少了。我們使用這些稀少的缺陷數據很難利用深度學習訓練一個理想的模型,往往都需要進行數據擴充,即通過某些手段增加我們的缺陷數據。 【數據擴充大家感興趣自己去了解下,GAN網絡也是實現數據擴充的主流手段】 上面説到的方法是基於缺陷數據來訓練的,是有監督的學習,學者們在漫長的研究中,考慮能不能使用一種無監督的方法來實現缺陷檢測呢?於是啊,AnoGAN就橫空出世了,它不需要缺陷數據進行訓練,而僅使用正常數據訓練模型,關於AnoGAN的細節後文詳細介紹。
關於GAN網絡,我已經介紹了幾篇,如下:
- [1]對抗生成網絡GAN系列——GAN原理及手寫數字生成小案例 🍁🍁🍁
- [2]對抗生成網絡GAN系列——DCGAN簡介及人臉圖像生成案例🍁🍁🍁
- [3]對抗生成網絡GAN系列——CycleGAN原理🍁🍁🍁
在閲讀本文之前建議大家對GAN有一定的瞭解,可以參考[1]和[2],關於[3]感興趣的可以看看,本篇文章用不到[3]相關知識。
準備好了嘛,我們開始發車了喔。🚖🚖🚖
AnoGAN 原理詳解✨✨✨
首先我們來看看AnoGAN的全稱,即Anomaly Detection with Generative Adversarial Networks
,中文是指使用生成對抗網絡實現異常檢測。這篇論文解決的是醫學影像中疾病的檢測,由於對醫學相關內容不瞭解,本文將完全將該算法從論文中剝離,只介紹算法原理,而不結合論文進行講述。想要了解論文詳情的可以點擊☞☞☞查看。
接下來就隨我一起來看看AnoGAN的原理。其實AnoGAN的原理是很簡單的,但是我看網上的資料總是説的摸稜兩可,我認為主要原因有兩點:其一是沒有把AnoGAN的原理分步來敍述,其二是有專家視角,它們認為我們都應該明白,但這對於新手來説理解也確實是有一定難度的。
在介紹AnoGAN的具體原理時,我先來談談AnoGAN的出發點,這非常重要,大家好好感受。我們知道,DCGAN是將一個噪聲或者説一個潛在變量映射成一張圖片,在我們訓練DCGAN時,都是使用某一種數據進行的,如[2]中使用的數據都是人臉,那麼這些數據都是正常數據,我們從一個潛在變量經DCGAN後生成的圖片應該也都是正常圖像。AnoGAN的想法就是我能否將一張圖片M映射成某個潛在變量呢,這其實是較難做到的。但是我們可以在某個空間不斷的查找一個潛在變量,使得這個潛在變量生成的圖片與圖片M儘可能接近。這就是AnoGAN的出發點,大家可能還不明白這麼做的意義,下文為大家詳細介紹。☘☘☘
AnoGAN其實是分兩個階段進行的,首先是訓練階段,然後是測試階段,我們一點點來看:
-
訓練階段
訓練階段僅使用正常的數據訓練對抗生成網絡。如我們使用手寫數字中的數字8作為本階段的數據進行訓練,那麼8就是正常數據。訓練結束後我們輸入一個向量z,生成網絡會將z變成8。不知道大家有沒有發現其實這階段就是[2]中的DCGAN呢? 【注意:訓練階段已經訓練好GAN網絡,後面的測試階段GAN網絡的權重是不在變換的】
-
測試階段
在訓練階段我們已經訓練好了一個GAN網絡,在這一階段我們就是要利用訓練好的網絡來進行缺陷檢測。如現在我們有一個數據6,此為缺陷數據 【訓練時使用8進行訓練,這裏的6即為缺陷數據】 。現在我們要做的就是搜索一個潛在變量並讓其生成的圖片與圖片6儘可能接近,具體實現如下:首先我們會定義一個潛在變量z,然後經過剛剛訓練的好的生成網絡,得到假圖像G(z),接着G(z)和缺陷數據6計算損失,這時候損失往往會比較大,我們不斷的更新z值,會使損失不斷的減少,在程序中我們可以設置更新z的次數,如更新500次後停止,此時我們認為將如今的潛在變量z送入生成網絡得到的假圖像已經和圖片6非常像了,於是我們將z再次送入生成網絡,得到G(z)。【注:由於潛在變量z送入的網絡是生成圖片8的,儘管通過搜索使G(z)和6儘可能相像,但還是存在一定差距,即它們的損失較大】 最後我們就可以計算G(z)和圖片6的損失,記為loss1,並將這個損失作為判斷是否有缺陷的重要依據。怎麼作為判斷是否有缺陷的重要依據呢?我再舉個例子大家就明白了,現在在測試階段我們傳入的不是缺陷數據,而是正常的數據8,此時應用相同的方法搜索潛在變量z,然後將最終的z送入生成網絡,得到G(z),最後計算G(z)和圖片8的損失。 【注:由於潛在變量z送入的網絡是生成圖片8的,所以最後生成的G(z)可以和數據8很像,即它們的損失較小】 通過以上分析, 我們可以發現當我們在測試階段傳入缺陷圖片時最終的損失大,傳入正常圖片時的損失小,這時候我們就可以設置一個合適的閾值來判斷圖像是否有缺陷了。🥂🥂🥂 這一段是整個AnoGAN的重點,大家多思考思考,相信你可以理解。我也畫了一個此過程的流程圖,大家可以參考一下,如下:
讀了上文,是不是對AnoGAN大致過程有了一定了解了呢!我覺得大家訓練階段肯定是沒問題的啦,就是一個DCGAN網絡,不清楚這個的話建議閲讀[2]瞭解DCGAN網絡。測試階段的難點就在於我們如何定義損失函數來更新z值,我們直接來看論文中此部分的損失,主要分為兩部分,分別是Residual Loss和Discrimination Loss,它們定義如下:
- Residual Loss
${\rm{R}}(z) = \sum {|x - G(z)|}$
上式z表示潛在變量,G(z)表示生成的假圖像,x表示輸入的測試圖片。上式表示生成的假圖像和輸入圖片之間的差距。如果生成的圖片越接近x,則R(z)越小。
- Discrimination Loss
$D(z) = \sum {|f(x) - f(G(z))|}$
上式z表示潛在變量,G(z)表示生成的假圖像,x表示輸入的測試圖片。f()表示將通過判別器,然後取判別器某一層的輸出結果。 【注:這裏使用的並非判別器的最終輸出,而是判別器某層的輸出,關於這一點,會在代碼講解時介紹】 這裏可以把判別器當作一個特徵提取網絡,我們將生成的假圖片和測試圖片都輸入判別器,看它們提取到特徵的差異。同樣,如果生成的圖片越接近x,則D(z)越小。
求得R(z)和D(z)後,我們定義它們的線性組合作為最終的損失,如下:
$Loss(z)=(1-\lambda)R(z)+\lambda D(z)$
通常,我們取$\lambda =0.1$
到這裏,AnoGAN的理論部分都介紹完了喔!!!不知道你理解了多少呢?如果覺得有些地方理解還差點兒意思的話,就來看看下面的代碼吧,這回對你理解AnoGAN非常有幫助。🌱🌱🌱
AnoGAN代碼實戰
如果大家和我一樣找過AnoGAN代碼的話,可能就會和我有一樣的感受,那就是太亂了。怎麼説呢,我認為從原理上來説,應該很好實現AnoGAN,但是我看Github上的代碼寫的挺複雜,不是很好理解,有的甚至起着AnoGAN的名字,實現的卻是一個簡單的DCGAN網絡,着實讓人有些無語。於是我打算按照自己的思路來實現一個AnoGAN,奈何卻出現了各種各樣的Bug,正當我心灰意冷時,看到了一篇外文的博客,寫的非常對我的胃口,於是按照它的思路實現了AnoGAN。這裏我還是想感概一下,我發現很多外文的博客確實寫的非常漂亮,我想這是值得我們學習的地方!!!🌼🌼🌼
代碼下載地址✨✨✨
本次我將源碼上傳到我的Github了,大家可以閲讀README文件瞭解代碼的使用,Github地址如下:
我認為你閲讀README文件後已經對這個項目的結構有所瞭解,我在下文也會幫大家分析分析源碼,但更多的時間大家應該自己動手去親自調試,這樣你會有不一樣的收穫。🌾🌾🌾
數據讀取✨✨✨
本次使用的數據為mnist手寫數字數據集,我們下載的是.csv格式的數據,這種格式方便讀取。讀取數據代碼如下:
python
## 讀取訓練集數據 (60000,785)
train = pd.read_csv(".\data\mnist_train.csv",dtype = np.float32)
## 讀取測試集數據 (10000,785)
test = pd.read_csv(".\data\mnist_test.csv",dtype = np.float32)
我們可以來看一下mnist數據集的格式是怎樣的,先來看看train中的內容,如下:
train的shape為(60000,785),其表示訓練集中共有60000個數據,即60000張手寫數字的圖片,每個數據都有785個值。我們來分析一下這785個數值的含義,第一個數值為標籤label,表示其表示哪個手寫數字,後784個數值為對應數字每個像素的值,手寫數字圖片大小為28×28,故一共有784個像素值。
解釋完訓練集數據的含義,那測試集也是一樣的啦,只不過數據較少,只有10000條數據,test的內容如下:
大家需要注意的是,上述的訓練集和測試集中的數據我們今天並不會全部用到。我們取訓練集中的前400個標籤為7或8的數據作為AnoGAN的訓練集,即7、8都為正常數據。取測試集前600個標籤為2、7、8作為測試數據,即測試集中有正常數據(7、8)和異常數據(2),相關代碼如下:
python
# 查詢訓練數據中標籤為7、8的數據,並取前400個
train = train.query("label in [7.0, 8.0]").head(400)
# 查詢訓練數據中標籤為7、8的數據,並取前400個
test = test.query("label in [2.0, 7.0, 8.0]").head(600)
可以看看此時的train和test的結果:
在AnoGAN中,我們是無監督的學習,因此是不需要標籤的,通過以下代碼去除train和test中的標籤:
python
# 取除標籤後的784列數據
train = train.iloc[:,1:].values.astype('float32')
test = test.iloc[:,1:].values.astype('float32')
去除標籤後train和test的結果如下:
可以看出,此時train和test中已經沒有了label類,它們的第二個維度也從785變成了784。
最後,我們將train和test reshape成圖片的格式,即28×28,代碼如下:
python
# train:(400,784)-->(400,28,28)
# test:(600,784)-->(600,28,28)
train = train.reshape(train.shape[0], 28, 28)
test = test.reshape(test.shape[0], 28, 28)
此時,train和test的維度發生變換,如下圖所示:
至此,我們的數據讀取部分就為大家介紹完了,是不是發現挺簡單的呢,加油吧!!!🥂🥂🥂
模型搭建
模型搭建真滴很簡單!!!大家之間看代碼吧。🌻🌻🌻
生成模型搭建
python
"""定義生成器網絡結構"""
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1, activation=nn.ReLU(inplace=True), bn=True):
seq = []
seq += [nn.ConvTranspose2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)]
if bn is True:
seq += [nn.BatchNorm2d(out_channel)]
seq += [activation]
return nn.Sequential(*seq)
seq = []
seq += [CBA(20, 64*8, stride=1, padding=0)]
seq += [CBA(64*8, 64*4)]
seq += [CBA(64*4, 64*2)]
seq += [CBA(64*2, 64)]
seq += [CBA(64, 1, activation=nn.Tanh(), bn=False)]
self.generator_network = nn.Sequential(*seq)
def forward(self, z):
out = self.generator_network(z)
return out
為了幫助大家理解,我繪製 了生成網絡的結構圖,如下:
判別模型搭建
python
"""定義判別器網絡結構"""
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1, activation=nn.LeakyReLU(0.1, inplace=True)):
seq = []
seq += [nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)]
seq += [nn.BatchNorm2d(out_channel)]
seq += [activation]
return nn.Sequential(*seq)
seq = []
seq += [CBA(1, 64)]
seq += [CBA(64, 64*2)]
seq += [CBA(64*2, 64*4)]
seq += [CBA(64*4, 64*8)]
self.feature_network = nn.Sequential(*seq)
self.critic_network = nn.Conv2d(64*8, 1, kernel_size=4, stride=1)
def forward(self, x):
out = self.feature_network(x)
feature = out
feature = feature.view(feature.size(0), -1)
out = self.critic_network(out)
return out, feature
同樣,為了方便大家理解,我也繪製了判別網絡的結構圖,如下:
這裏大家需要稍稍注意一下,判別網絡有兩個輸出,一個是最終的輸出,還有一個是第四個CBA BLOCK提取到的特徵,這個在理論部分介紹損失函數時有提及。
模型訓練
數據集加載
python
class image_data_set(Dataset):
def __init__(self, data):
self.images = data[:,:,:,None]
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize(64, interpolation=InterpolationMode.BICUBIC),
transforms.Normalize((0.1307,), (0.3081,))
])
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
return self.transform(self.images[idx])
# 加載訓練數據
train_set = image_data_set(train)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
這部分不難,但我提醒大家注意一下這句:transforms.Resize(64, interpolation=InterpolationMode.BICUBIC)
,即我們採用插值算法將原來2828大小的圖片上採樣成了6464大小。 【感興趣的這裏也可以不對其進行上採樣,這樣的話大家需要修改一下上節的模型,可以試試效果喔】
加載模型、定義優化器、損失函數等參數
python
# 指定設備
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
# batch_size默認128
batch_size = args.batch_size
# 加載模型
G = Generator().to(device)
D = Discriminator().to(device)
# 訓練模式
G.train()
D.train()
# 設置優化器
optimizerG = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.0, 0.9))
optimizerD = torch.optim.Adam(D.parameters(), lr=0.0004, betas=(0.0, 0.9))
# 定義損失函數
criterion = nn.BCEWithLogitsLoss(reduction='mean')
訓練GAN網絡
python
"""
訓練
"""
# 開始訓練
for epoch in range(args.epochs):
# 定義初始損失
log_g_loss, log_d_loss = 0.0, 0.0
for images in train_loader:
images = images.to(device)
## 訓練判別器 Discriminator
# 定義真標籤(全1)和假標籤(全0) 維度:(batch_size)
label_real = torch.full((images.size(0),), 1.0).to(device)
label_fake = torch.full((images.size(0),), 0.0).to(device)
# 定義潛在變量z 維度:(batch_size,20,1,1)
z = torch.randn(images.size(0), 20).to(device).view(images.size(0), 20, 1, 1).to(device)
# 潛在變量喂入生成網絡--->fake_images:(batch_size,1,64,64)
fake_images = G(z)
# 真圖像和假圖像送入判別網絡,得到d_out_real、d_out_fake 維度:都為(batch_size,1,1,1)
d_out_real, _ = D(images)
d_out_fake, _ = D(fake_images)
# 損失計算
d_loss_real = criterion(d_out_real.view(-1), label_real)
d_loss_fake = criterion(d_out_fake.view(-1), label_fake)
d_loss = d_loss_real + d_loss_fake
# 誤差反向傳播,更新損失
optimizerD.zero_grad()
d_loss.backward()
optimizerD.step()
## 訓練生成器 Generator
# 定義潛在變量z 維度:(batch_size,20,1,1)
z = torch.randn(images.size(0), 20).to(device).view(images.size(0), 20, 1, 1).to(device)
fake_images = G(z)
# 假圖像喂入判別器,得到d_out_fake 維度:(batch_size,1,1,1)
d_out_fake, _ = D(fake_images)
# 損失計算
g_loss = criterion(d_out_fake.view(-1), label_real)
# 誤差反向傳播,更新損失
optimizerG.zero_grad()
g_loss.backward()
optimizerG.step()
## 累計一個epoch的損失,判別器損失和生成器損失分別存放到log_d_loss、log_g_loss中
log_d_loss += d_loss.item()
log_g_loss += g_loss.item()
## 打印損失
print(f'epoch {epoch}, D_Loss:{log_d_loss / 128:.4f}, G_Loss:{log_g_loss / 128:.4f}')
## 展示生成器存儲的圖片,存放在result文件夾下的G_out.jpg
z = torch.randn(8, 20).to(device).view(8, 20, 1, 1).to(device)
fake_images = G(z)
torchvision.utils.save_image(fake_images,f"result\G_out.jpg")
這部分就是訓練一個DCGAN網絡,到目前為止其實也都可以認為是DCGAN的內容。我們可以來看一下輸出的G_out.jpg
圖片:
這裏我們可以看到訓練是有了效果的,但會發現不是特別好。我分析有兩點原因,其一是我們的模型不好,且GAN本身就容易出現模式崩潰的問題;其二是我們的數據選取的少,在數據讀取時訓練集我們只取了前400個數據,但實際上我們一共可以取12116個,大家可以嘗試增加數據,我想數據多了後效果肯定比這個好,大家快去試試吧!!!🍉🍉🍉
缺陷檢測✨✨✨
這部分才是AnoGAN的重點,首先我們先定義損失的計算,如下:
python
## 定義缺陷計算的得分
def anomaly_score(input_image, fake_image, D):
# Residual loss 計算
residual_loss = torch.sum(torch.abs(input_image - fake_image), (1, 2, 3))
# Discrimination loss 計算
_, real_feature = D(input_image)
_, fake_feature = D(fake_image)
discrimination_loss = torch.sum(torch.abs(real_feature - fake_feature), (1))
# 結合Residual loss和Discrimination loss計算每張圖像的損失
total_loss_by_image = 0.9 * residual_loss + 0.1 * discrimination_loss
# 計算總損失,即將一個batch的損失相加
total_loss = total_loss_by_image.sum()
return total_loss, total_loss_by_image, residual_loss
大家可以對比一下理論部分損失函數的介紹,看看是不是一樣的呢。
接着我們就需要不斷的搜索潛在變量z了,使其與輸入圖片儘可能接近,代碼如下:
python
# 加載測試數據
test_set = image_data_set(test)
test_loader = DataLoader(test_set, batch_size=5, shuffle=False)
input_images = next(iter(test_loader)).to(device)
# 定義潛在變量z 維度:(5,20,1,1)
z = torch.randn(5, 20).to(device).view(5, 20, 1, 1)
# z的requires_grad參數設置成Ture,讓z可以更新
z.requires_grad = True
# 定義優化器
z_optimizer = torch.optim.Adam([z], lr=1e-3)
# 搜索z
for epoch in range(5000):
fake_images = G(z)
loss, _, _ = anomaly_score(input_images, fake_images, D)
z_optimizer.zero_grad()
loss.backward()
z_optimizer.step()
if epoch % 1000 == 0:
print(f'epoch: {epoch}, loss: {loss:.0f}')
執行完上述代碼後,我們得到了一個較理想的潛在變量,這時候再用z來生成圖片,並基於生成圖片和輸入圖片來計算損失,同時,我們也保存了輸入圖片和生成圖片,並打印了它們之前的損失,相關代碼如下:
python
fake_images = G(z)
_, total_loss_by_image, _ = anomaly_score(input_images, fake_images, D)
print(total_loss_by_image.cpu().detach().numpy())
torchvision.utils.save_image(input_images, f"result/Nomal.jpg")
torchvision.utils.save_image(fake_images, f"result/ANomal.jpg")
我們可以來看看最後的結果哦,如下:
可以看到,當輸入圖像為2時(此為缺陷),生成的圖像也是8,它們的損失最高為464040.44。這時候如果我們設置一個閾值為430000,高於這個閾值的即為異常圖片,低於這個閾值的即為正常圖片,那麼我們是不是就可以通過AnoGAN來實現缺陷的檢測了呢!!!🍒🍒🍒
總結
到這裏,AnoGAN的所有內容就介紹完了,大家好好感受感受它的思想,其實是很簡單的,但是又非常巧妙。最後我不知道大家有沒有發現AnoGAN一個非常明顯的缺陷,那就是我們每次在判斷異常時要不斷的搜索潛在變量z,這是非常耗時的。而很多任務對時間的要求還是很高的,所以AnoGAN還有許多可以改進的地方,後續博文我會帶大家繼續學習GAN網絡在缺陷檢測中的應用,我們下期見。🖐🏽🖐🏽🖐🏽
參考文獻
AnoGAN論文🍁🍁🍁
深度學習論文筆記(異常檢測) 🍁🍁🍁
如若文章對你有所幫助,那就🛴🛴🛴
咻咻咻咻~~duang\~~點個讚唄
- 兔年到了,一起來寫個春聯吧
- CV攻城獅入門VIT(vision transformer)之旅——VIT代碼實戰篇
- 對抗生成網絡GAN系列——GANomaly原理及源碼解析
- 對抗生成網絡GAN系列——WGAN原理及實戰演練
- CV攻城獅入門VIT(vision transformer)之旅——近年超火的Transformer你再不瞭解就晚了!
- 對抗生成網絡GAN系列——DCGAN簡介及人臉圖像生成案例
- 對抗生成網絡GAN系列——CycleGAN簡介及圖片春冬變換案例
- 對抗生成網絡GAN系列——AnoGAN原理及缺陷檢測實戰
- 目標檢測系列——Faster R-CNN原理詳解
- 目標檢測系列——Fast R-CNN原理詳解
- 目標檢測系列——開山之作RCNN原理詳解
- 【古月21講】ROS入門系列(4)——參數使用與編程方法、座標管理系統、tf座標系廣播與監聽的編程實現、launch啟動文件的使用方法
- 使用kitti數據集實現自動駕駛——繪製出所有物體的行駛軌跡
- 使用kitti數據集實現自動駕駛——發佈照片、點雲、IMU、GPS、顯示2D和3D偵測框
- 基於pytorch搭建ResNet神經網絡用於花類識別
- 基於pytorch搭建GoogleNet神經網絡用於花類識別
- 基於pytorch搭建VGGNet神經網絡用於花類識別
- UWB原理分析
- 論文閲讀:RRPN:RADAR REGION PROPOSAL NETWORK FOR OBJECT DETECTION IN AUTONOMOUS
- 凸優化理論基礎2——凸集和錐