基於pytorch搭建VGGNet神經網絡用於花類識別
持續創作,加速成長!這是我參與「掘金日新計劃 · 6 月更文挑戰」的第9天,點擊查看活動詳情
🍊作者簡介:禿頭小蘇,致力於用最通俗的語言描述問題
🍊往期回顧:卡爾曼濾波系列1——卡爾曼濾波 基於pytorch搭建AlexNet神經網絡用於花類識別
🍊近期目標:擁有5000粉絲
🍊支持小蘇:點贊👍🏼、收藏⭐、留言📩
基於pytorch搭建VGGNet神經網絡用於花類識別
寫在前面
上一篇寫過基於pytorch搭建AlexNet神經網絡用於花類識別項目實戰,建議閲讀此篇前先弄明白上篇所述之事🍍🍍🍍此外本節搭建的網絡模型是VGG,需要你對VGG的網絡結構有較深入的瞭解,還不清楚的戳此圖標☞☞☞瞭解詳情。
這篇文章同樣是對花的類別進行識別,和上一篇使用AlexNet進行識別整體步驟是完全類似的,主要區別就是網絡的結構有所不同,因此,本節將只針對VGG的網絡結構搭建進行詳細的講解,其餘部分基本和上一篇一致,不再贅述,大家可自行下載代碼進一步研究。
VGGNet網絡模型搭建
這一部分的代碼可能真的能讓你感受到代碼之美,寫的確實太漂亮了🍁🍁🍁首先我們知道VGG一共有四種結構,分別為VGG11、VGG13、VGG16、VGG19。我想若是讓我們單獨的構建一種VGG網絡是不難辦到的,VGG這種直筒型的結構用代碼實現是較容易的。官方的demo中通過一個字典將4中結構的VGG網絡放在了一起,只需要我們在調用的時候傳入相關參數就可以了,實在是太妙了!!!下面讓我們一起來學習一下🥗🥗🥗
首先我們定義了一個字典cfgs
,字典中有四個鍵值對,每個鍵對應VGG的一種結構,每個值是對應結構中的一些參數。
python
cfgs = {
'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
具體的,我們拿vgg16來進行相關解釋:
有了這個字典之後,我們就可以通過傳入相關參數來構建特徵提取層:
```python
在cfgs傳入"vgg16",得到一個列表cfg
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
cfg = cfgs["vgg16"]
定義特徵提取層函數make_features,將cfg作為參數傳入
def make_features(cfg: list): layers = [] in_channels = 3 #遍歷整個cfg列表 for v in cfg: #若v的值為"M",則在層結構layers中添加一個最大池化層,其kernel_size=2, stride=2 if v == "M": layers += [nn.MaxPool2d(kernel_size=2, stride=2)] #若v值為數字,則在層結構layers中添加一個卷積層核和Relu激活函數 else: conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) layers += [conv2d, nn.ReLU(True)] in_channels = v return nn.Sequential(*layers)
``` 通過上面程序,我們可以來看一下得到的layers內部的部分結構【只選了前幾個層】,如下圖所示:
接下來我們可以來構建我們的分類層,即全連接層的部分:
python
self.classifier = nn.Sequential(
nn.Linear(512*7*7, 4096),
nn.ReLU(True),
nn.Dropout(p=0.5),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(p=0.5),
nn.Linear(4096, num_classes)
)
這些都準備好之後,我們就可以來定義我們的網絡模型了,如下所示:
```python class VGG(nn.Module): def init(self, features, num_classes=1000): super(VGG, self).init() self.features = features self.classifier = nn.Sequential( nn.Linear(51277, 4096), nn.ReLU(True), nn.Dropout(p=0.5), nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(p=0.5), nn.Linear(4096, num_classes) ) if init_weights: self._initialize_weights()
def forward(self, x):
# N x 3 x 224 x 224
x = self.features(x)
# N x 512 x 7 x 7
x = torch.flatten(x, start_dim=1)
# N x 512*7*7
x = self.classifier(x)
return x
``` 至此,我們的模型就創建完畢,最後讓我們來看看我們剛剛創建的VGG模型結構:
訓練結果展示
本篇文章不再詳細講解訓練步驟,和基於pytorch搭建AlexNet神經網絡用於花類識別基本一致。這裏展示一下訓練結果,如下圖所示:
其準確率達到了0.761,我們可以再來看看我們保存的VGG模型,如下圖,可以看出VGG用到的參數還是很多的,有500+M,這和我們的理論部分也是契合的。
小結
對於這一部分我強烈建議大家去使用Pycharm的調試功能,一步步的看每次運行的結果,這樣你會發現代碼結構特別的清晰。
參考視頻:http://www.bilibili.com/video/BV1i7411T7ZN/?spm_id_from=333.788🌸🌸🌸
如若文章對你有所幫助,那就🛴🛴🛴
咻咻咻咻~~duang\~~點個讚唄
- 兔年到了,一起來寫個春聯吧
- CV攻城獅入門VIT(vision transformer)之旅——VIT代碼實戰篇
- 對抗生成網絡GAN系列——GANomaly原理及源碼解析
- 對抗生成網絡GAN系列——WGAN原理及實戰演練
- CV攻城獅入門VIT(vision transformer)之旅——近年超火的Transformer你再不瞭解就晚了!
- 對抗生成網絡GAN系列——DCGAN簡介及人臉圖像生成案例
- 對抗生成網絡GAN系列——CycleGAN簡介及圖片春冬變換案例
- 對抗生成網絡GAN系列——AnoGAN原理及缺陷檢測實戰
- 目標檢測系列——Faster R-CNN原理詳解
- 目標檢測系列——Fast R-CNN原理詳解
- 目標檢測系列——開山之作RCNN原理詳解
- 【古月21講】ROS入門系列(4)——參數使用與編程方法、座標管理系統、tf座標系廣播與監聽的編程實現、launch啟動文件的使用方法
- 使用kitti數據集實現自動駕駛——繪製出所有物體的行駛軌跡
- 使用kitti數據集實現自動駕駛——發佈照片、點雲、IMU、GPS、顯示2D和3D偵測框
- 基於pytorch搭建ResNet神經網絡用於花類識別
- 基於pytorch搭建GoogleNet神經網絡用於花類識別
- 基於pytorch搭建VGGNet神經網絡用於花類識別
- UWB原理分析
- 論文閲讀:RRPN:RADAR REGION PROPOSAL NETWORK FOR OBJECT DETECTION IN AUTONOMOUS
- 凸優化理論基礎2——凸集和錐