全卷積網路(FCN)實戰:使用FCN實現語義分割

語言: CN / TW / HK
摘要:FCN對影象進行畫素級的分類,從而解決了語義級別的影象分割問題。

本文分享自華為雲社群《全卷積網路(FCN)實戰:使用FCN實現語義分割》,作者: AI浩。

FCN對影象進行畫素級的分類,從而解決了語義級別的影象分割(semantic segmentation)問題。與經典的CNN在卷積層之後使用全連線層得到固定長度的特徵向量進行分類(全聯接層+softmax輸出)不同,FCN可以接受任意尺寸的輸入影象,採用反捲積層對最後一個卷積層的feature map進行上取樣, 使它恢復到輸入影象相同的尺寸,從而可以對每個畫素都產生了一個預測, 同時保留了原始輸入影象中的空間資訊, 最後在上取樣的特徵圖上進行逐畫素分類。

下圖是語義分割所採用的全卷積網路(FCN)的結構示意圖:

傳統的基於CNN的分割方法缺點?

傳統的基於CNN的分割方法:為了對一個畫素分類,使用該畫素周圍的一個影象塊作為CNN的輸入,用於訓練與預測,這種方法主要有幾個缺點:

1)儲存開銷大,例如,對每個畫素使用15 * 15的影象塊,然後不斷滑動視窗,將影象塊輸入到CNN中進行類別判斷,因此,需要的儲存空間隨滑動視窗的次數和大小急劇上升;

2)效率低下,相鄰畫素塊基本上是重複的,針對每個畫素塊逐個計算卷積,這種計算有很大程度上的重複;

3)畫素塊的大小限制了感受區域的大小,通常畫素塊的大小比整幅影象的大小小很多,只能提取一些區域性特徵,從而導致分類效能受到限制。
而全卷積網路(FCN)則是從抽象的特徵中恢復出每個畫素所屬的類別。即從影象級別的分類進一步延伸到畫素級別的分類。

FCN改變了什麼?

對於一般的分類CNN網路,如VGG和Resnet,都會在網路的最後加入一些全連線層,經過softmax後就可以獲得類別概率資訊。但是這個概率資訊是1維的,即只能標識整個圖片的類別,不能標識每個畫素點的類別,所以這種全連線方法不適用於影象分割。

而FCN提出可以把後面幾個全連線都換成卷積,這樣就可以獲得一張2維的feature map,後接softmax層獲得每個畫素點的分類資訊,從而解決了分割問題,如圖。

FCN缺點

(1)得到的結果還是不夠精細。進行8倍上取樣雖然比32倍的效果好了很多,但是上取樣的結果還是比較模糊和平滑,對影象中的細節不敏感。
(2)對各個畫素進行分類,沒有充分考慮畫素與畫素之間的關係。忽略了在通常的基於畫素分類的分割方法中使用的空間規整(spatial regularization)步驟,缺乏空間一致性。

資料集

本例的資料集採用PASCAL VOC 2012 資料集,它有二十個類別:

Person:person

Animal: bird, cat, cow, dog, horse, sheep

Vehicle:aeroplane, bicycle, boat, bus, car, motorbike, train

Indoor: bottle, chair, dining table, potted plant, sofa, tv/monitor

下載地址:The PASCAL Visual Object Classes Challenge 2012 (VOC2012) (ox.ac.uk)

資料集的結構:

VOCdevkit
    └── VOC2012
         ├── Annotations               所有的影象標註資訊(XML檔案)
         ├── ImageSets    
         │   ├── Action                人的行為動作影象資訊
         │   ├── Layout                人的各個部點陣圖像資訊
         │   │
         │   ├── Main                  目標檢測分類影象資訊
         │   │     ├── train.txt       訓練集(5717)
         │   │     ├── val.txt         驗證集(5823)
         │   │     └── trainval.txt    訓練集+驗證集(11540)
         │   │
         │   └── Segmentation          目標分割影象資訊
         │         ├── train.txt       訓練集(1464)
         │         ├── val.txt         驗證集(1449)
         │         └── trainval.txt    訓練集+驗證集(2913)
         │ 
         ├── JPEGImages                所有影象檔案
         ├── SegmentationClass         語義分割png圖(基於類別)
         └── SegmentationObject        例項分割png圖(基於目標)

資料集包含物體檢測和語義分割,我們只需要語義分割的資料集,所以可以考慮把多餘的圖片刪除,刪除的思路:

1、獲取所有圖片的name。

2、獲取所有語義分割mask的name。

3、求二者的差集,然後將差集的name刪除。

程式碼如下:

import glob
import os
image_all = glob.glob('data/VOCdevkit/VOC2012/JPEGImages/*.jpg')
image_all_name = [image_file.replace('\\', '/').split('/')[-1].split('.')[0] for image_file in image_all]

image_SegmentationClass = glob.glob('data/VOCdevkit/VOC2012/SegmentationClass/*.png')
image_se_name= [image_file.replace('\\', '/').split('/')[-1].split('.')[0] for image_file in image_SegmentationClass]
image_other=list(set(image_all_name) - set(image_se_name))
print(image_other)
for image_name in image_other:
    os.remove('data/VOCdevkit/VOC2012/JPEGImages/{}.jpg'.format(image_name))

程式碼連結

本例選用的程式碼來自deep-learning-for-image-processing/pytorch_segmentation/fcn at master · WZMIAOMIAO/deep-learning-for-image-processing (github.com)

其他的程式碼也有很多,這篇比較好理解!

其實還有個比較好的影象分割庫:https://github.com/qubvel/segmentation_models.pytorch

這個影象分割集合由俄羅斯的程式設計師小哥Pavel Yakubovskiy一手打造。在後面的文章,我也會使用這個庫演示。

專案結構

├── src: 模型的backbone以及FCN的搭建
├── train_utils: 訓練、驗證以及多GPU訓練相關模組
├── my_dataset.py: 自定義dataset用於讀取VOC資料集
├── train.py: 以fcn_resnet50(這裡使用了Dilated/Atrous Convolution)進行訓練
├── predict.py: 簡易的預測指令碼,使用訓練好的權重進行預測測試
├── validation.py: 利用訓練好的權重驗證/測試資料的mIoU等指標,並生成record_mAP.txt檔案
└── pascal_voc_classes.json: pascal_voc標籤檔案

由於程式碼很多不能一一講解,所以,接下來對重要的程式碼做剖析。

自定義資料集讀取

my_dataset.py自定義資料讀取的方法,程式碼如下:

import os
import torch.utils.data as data
from PIL import Image

class VOCSegmentation(data.Dataset):
    def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
        super(VOCSegmentation, self).__init__()
        assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
        root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
        root=root.replace('\\','/')
        assert os.path.exists(root), "path '{}' does not exist.".format(root)
        image_dir = os.path.join(root, 'JPEGImages')
        mask_dir = os.path.join(root, 'SegmentationClass')

        txt_path = os.path.join(root, "ImageSets", "Segmentation", txt_name)
        txt_path=txt_path.replace('\\','/')
        assert os.path.exists(txt_path), "file '{}' does not exist.".format(txt_path)
        with open(os.path.join(txt_path), "r") as f:
            file_names = [x.strip() for x in f.readlines() if len(x.strip()) > 0]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert (len(self.images) == len(self.masks))
        self.transforms = transforms

匯入需要的包。

定義VOC資料集讀取類VOCSegmentation。在init方法中,核心是讀取image列表和mask列表。

    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

__getitem__方法是獲取單張圖片和圖片對應的mask,然後對其做資料增強。

 def collate_fn(batch):
        images, targets = list(zip(*batch))
        batched_imgs = cat_list(images, fill_value=0)
        batched_targets = cat_list(targets, fill_value=255)
        return batched_imgs, batched_targets

collate_fn方法是對一個batch中資料呼叫cat_list做資料對齊。

在train.py中torch.utils.data.DataLoader呼叫

 train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True,
                                               collate_fn=train_dataset.collate_fn)
  val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             collate_fn=val_dataset.collate_fn)

訓練

重要引數

開啟train.py,我們先認識一下重要的引數:

def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description="pytorch fcn training")
    # 資料集的根目錄(VOCdevkit)所在的資料夾
    parser.add_argument("--data-path", default="data/", help="VOCdevkit root")
    parser.add_argument("--num-classes", default=20, type=int)
    parser.add_argument("--aux", default=True, type=bool, help="auxilier loss")
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("-b", "--batch-size", default=32, type=int)
    parser.add_argument("--epochs", default=30, type=int, metavar="N",
                        help="number of total epochs to train")

    parser.add_argument('--lr', default=0.0001, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    # 是否使用混合精度訓練
    parser.add_argument("--amp", default=False, type=bool,
                        help="Use torch.cuda.amp for mixed precision training")

    args = parser.parse_args()

    return args

data-path:定義資料集的根目錄(VOCdevkit)所在的資料夾

num-classes:檢測目標類別數(不包含背景)。

aux:是否使用aux_classifier。

device:使用cpu還是gpu訓練,預設是cuda。

batch-size:BatchSize設定。

epochs:epoch的個數。

lr:學習率。

resume:繼續訓練時候,選擇用的模型。

start-epoch:起始的epoch,針對再次訓練時,可以不需要從0開始。

amp:是否使用torch的自動混合精度訓練。

資料增強

增強呼叫transforms.py中的方法。

訓練集的增強如下:

class SegmentationPresetTrain:
    def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # 隨機Resize的最小尺寸
        min_size = int(0.5 * base_size)
        # 隨機Resize的最大尺寸
        max_size = int(2.0 * base_size)
        # 隨機Resize增強。
        trans = [T.RandomResize(min_size, max_size)]
        if hflip_prob > 0:
            #隨機水平翻轉
            trans.append(T.RandomHorizontalFlip(hflip_prob))
        trans.extend([
            #隨機裁剪
            T.RandomCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
        self.transforms = T.Compose(trans)

    def __call__(self, img, target):
        return self.transforms(img, target)

訓練集增強,包括隨機Resize、隨機水平翻轉、隨即裁剪。

驗證集增強:

class SegmentationPresetEval:
    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = T.Compose([
            T.RandomResize(base_size, base_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    def __call__(self, img, target):
        return self.transforms(img, target)

驗證集的增強比較簡單,只有隨機Resize。

Main方法

對Main方法,我做了一些修改,修改的程式碼如下:

 #定義模型,並載入預訓練
    model = fcn_resnet50(pretrained=True)
    # 預設classes是21,如果不是21,則要修改類別。
    if num_classes != 21:
        model.classifier[4] = torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model.aux_classifier[4] = torch.nn.Conv2d(256, num_classes, kernel_size=(1, 1), stride=(1, 1))
    print(model)
    model.to(device)
    # 如果有多張顯示卡,則使用多張顯示卡
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)

模型,我改為pytorch官方的模型了,如果能使用官方的模型儘量使用官方的模型。

預設類別是21,如果不是21,則要修改類別。

檢測系統中是否有多張卡,如果有多張卡則使用多張卡不能浪費資源。

如果不想使用所有的卡,而是指定其中的幾張卡,可以使用:

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

也可以在DataParallel方法中設定:

model = torch.nn.DataParallel(model,device_ids=[0,1])

如果使用了多顯示卡,再使用模型的引數就需要改為model.module.xxx,例如:

  params = [p for p in model.module.aux_classifier.parameters() if p.requires_grad]
            params_to_optimize.append({"params": params, "lr": args.lr * 10})

上面的都完成了就可以開始訓練了,如下圖:

測試

在開始測試之前,我們還要獲取到調色盤,新建指令碼get_palette.py,程式碼如下:

import json
import numpy as np
from PIL import Image
# 讀取mask標籤
target = Image.open("./2007_001288.png")
# 獲取調色盤
palette = target.getpalette()

palette = np.reshape(palette, (-1, 3)).tolist()
print(palette)
# 轉換成字典子形式
pd = dict((i, color) for i, color in enumerate(palette))

json_str = json.dumps(pd)
with open("palette.json", "w") as f:
    f.write(json_str)

選取一張mask,然後使用getpalette方法獲取,然後將其轉為字典的格式儲存。

接下來,開始預測部分,新建predict.py,插入以下程式碼:

import os
import time
import json
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
from torchvision.models.segmentation import fcn_resnet50

匯入程式需要的包檔案,然在mian方法中:

def main():
    aux = False  # inference time not need aux_classifier
    classes = 20
    weights_path = "./save_weights/model_5.pth"
    img_path = "./2007_000123.jpg"
    palette_path = "./palette.json"
    assert os.path.exists(weights_path), f"weights {weights_path} not found."
    assert os.path.exists(img_path), f"image {img_path} not found."
    assert os.path.exists(palette_path), f"palette {palette_path} not found."
    with open(palette_path, "rb") as f:
        pallette_dict = json.load(f)
        pallette = []
        for v in pallette_dict.values():
            pallette += v
  • 定義是否需要aux_classifier,預測不需要aux_classifier,所以設定為False。
  • 設定類別為20,不包括背景。
  • 定義權重的路徑。
  • 定義調色盤的路徑。
  • 讀去調色盤。

接下來,是載入模型,單顯示卡訓練出來的模型和多顯示卡訓練出來的模型載入有區別,我們先看單顯示卡訓練出來的模型如何載入。

   model = fcn_resnet50(num_classes=classes+1)
    print(model)
    # 單顯示卡訓練出來的模型,載入
    # delete weights about aux_classifier
    weights_dict = torch.load(weights_path, map_location='cpu')['model']
    for k in list(weights_dict.keys()):
        if "aux_classifier" in k:
            del weights_dict[k]

    # load weights
    model.load_state_dict(weights_dict)
    model.to(device)

定義模型fcn_resnet50,num_classes設定為類別+1(背景)

載入訓練好的模型,並將aux_classifier刪除。

然後載入權重。

再看多顯示卡的模型如何載入

    # create model
    model = fcn_resnet50(num_classes=classes+1)
    model = torch.nn.DataParallel(model)
    # delete weights about aux_classifier
    weights_dict = torch.load(weights_path, map_location='cpu')['model']
    print(weights_dict)
    for k in list(weights_dict.keys()):
        if "aux_classifier" in k:
            del weights_dict[k]
    # load weights
    model.load_state_dict(weights_dict)
    model=model.module
    model.to(device)

定義模型fcn_resnet50,num_classes設定為類別+1(背景),將模型放入DataParallel類中。

載入訓練好的模型,並將aux_classifier刪除。

載入權重。

執行torch.nn.DataParallel(model)時,model被放在了model.module,所以model.module才真正需要的模型。所以我們在這裡將model.module賦值給model。

接下來是影象資料的處理

  # load image
    original_img = Image.open(img_path)

    # from pil image to tensor and normalize
    data_transform = transforms.Compose([transforms.Resize(520),
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                                              std=(0.229, 0.224, 0.225))])
    img = data_transform(original_img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

載入影象。

對影象做Resize、標準化、歸一化處理。

使用torch.unsqueeze增加一個維度。

完成影象的處理後,就可以開始預測了。

	model.eval()  # 進入驗證模式
    with torch.no_grad():
        # init model
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        t_start = time_synchronized()
        output = model(img.to(device))
        t_end = time_synchronized()
        print("inference+NMS time: {}".format(t_end - t_start))

        prediction = output['out'].argmax(1).squeeze(0)
        prediction = prediction.to("cpu").numpy().astype(np.uint8)
        np.set_printoptions(threshold=sys.maxsize)
        print(prediction.shape)
        mask = Image.fromarray(prediction)
        mask.putpalette(pallette)
        mask.save("test_result.png")

將預測後的結果儲存到test_result.png中。檢視執行結果:

原圖:

結果:

打印出來的資料:

類別列表:

{
    "aeroplane": 1,
    "bicycle": 2,
    "bird": 3,
    "boat": 4,
    "bottle": 5,
    "bus": 6,
    "car": 7,
    "cat": 8,
    "chair": 9,
    "cow": 10,
    "diningtable": 11,
    "dog": 12,
    "horse": 13,
    "motorbike": 14,
    "person": 15,
    "pottedplant": 16,
    "sheep": 17,
    "sofa": 18,
    "train": 19,
    "tvmonitor": 20
}

從結果來看,已經預測出來影象上的類別是“train”。

總結

這篇文章的核心內容是講解如何使用FCN實現影象的語義分割。

在文章的開始,我們講了一些FCN的結構和優缺點。然後,講解了如何讀取資料集。接下來,告訴大家如何實現訓練。最後,是測試以及結果展示。希望本文能給大家帶來幫助。

完整程式碼:https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/83778007

 

點選關注,第一時間瞭解華為雲新鮮技術~