The Swin-Unet Segmentation Network

Swin-Unet is a new segmentation network that uses Swin Transformer as its backbone (see the earlier Swin Transformer introduction) and combines it with the overall design of U-Net (see the U-Net section in Tensorflow深度學習演算法整理(三)).

It differs from Swin Transformer in a few places. On the encoder side it has the same 4 stages as Swin Transformer, but the number of Swin Transformer Blocks per stage is [2, 2, 2, 2] rather than Swin Transformer's [2, 2, 6, 2]. On the decoder side, since the feature maps are upsampled, Patch Embedding and Patch Merging are no longer used; instead it uses Patch Expanding, which is the inverse of Patch Merging.

Let's look at the implementation of Patch Expanding:

import torch.nn as nn
from einops import rearrange
class PatchExpand(nn.Module):
    """
    塊狀擴充,尺寸翻倍,通道數減半
    """
    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        """
        Args:
            input_resolution: 解碼過程的feature map的寬高
            dim: frature map通道數
            dim_scale: 通道數擴充的倍數
            norm_layer: 通道方向歸一化
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        # expand the channel count with a fully connected layer
        self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        # first double the channel count with the linear layer
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        # split the expanded channels into 2x2 spatial blocks and stitch them back
        # into the feature map, which doubles its height and width
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)
        # the channels were doubled and are then divided by 4, so overall the channel count is halved
        x = x.view(B, -1, C // 4)
        x = self.norm(x)

        return x

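As a quick sanity check of this doubling/halving behaviour, here is a minimal sketch (assuming the PatchExpand class above is in scope) that pushes a dummy 7x7x768 bottleneck feature through the layer:

import torch

# a dummy bottleneck feature: batch 1, 7*7 tokens, 768 channels
x = torch.randn(1, 7 * 7, 768)
up = PatchExpand(input_resolution=(7, 7), dim=768)
y = up(x)
# tokens go from 49 to 196 (14x14) while channels drop from 768 to 384
print(y.shape)  # torch.Size([1, 196, 384])
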
The encoder side is essentially the same as in Swin Transformer, so let's focus on the decoder. It uses the BasicLayer_up class to pair SwinTransformerBlocks with a Patch Expanding layer.

class BasicLayer_up(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    一個BasicLayer_up包含偶數個SwinTransformerBlock和一個upsamele層(即Patch Expanding層)
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
        """
        Args:
            dim: feature map通道數
            input_resolution: feature map的寬高
            depth: 各個Stage中,Swin Transformer Block的數量
            num_heads: 多頭注意力各個Stage中的頭數
            window_size: 視窗自注意力機制的視窗中的patch數
            mlp_ratio: 層感知機模組中第一個全連線層輸出的通道倍數
            qkv_bias: 如果是True的話,對自注意力公式中的Q、K、V增加一個可學習的偏置
            qk_scale: 視窗自注意力公式常數
            drop: dropout rate,預設為0
            attn_drop: 用於自注意力機制中的dropout rate,預設為0
            drop_path: 在Swin Transformer Block中,有一定概率丟棄整個直連分支,包括
                       LN、W-MSA或者SW-MSA,只保留直連的連線,是一種網路深度的隨機性,預設為0
            norm_layer: 通道方向歸一化
            upsample: 使用Patch Expanding來升取樣
            use_checkpoint: 是否使用Pytorch中間資料儲存機制
        """

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build SwinTransformerBlock
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 # decides whether the block uses W-MSA or SW-MSA: shift 0 -> W-MSA, window_size//2 -> SW-MSA
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch expanding layer
        # None in the last decoder stage
        if upsample is not None:
            self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
        else:
            self.upsample = None

    def forward(self, x):
        # pass through each SwinTransformerBlock
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        # upsample with Patch Expanding
        if self.upsample is not None:
            x = self.upsample(x)
        return x

The SwinTransformerBlock code is the same as in Swin Transformer, so it is not repeated here.

There is also a skip connection from the encoder to the decoder. To see it we need to look at the main Swin-Unet class.

class SwinTransformerSys(nn.Module):
    """ Swin-UNet網路模型
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, final_upsample="expand_first", **kwargs):
        """
        Args:
            img_size: 原始影象尺寸
            patch_size: 一個patch中的畫素點數
            in_chans: 進入網路的圖片通道數
            num_classes: 分類數量
            embed_dim: feature map通道數
            depths: 編碼器各個Stage中,Swin Transformer Block的數量
            depths_decoder: 解碼器各個Stage中,Swin Transformer Block的數量
            num_heads: 多頭注意力各個Stage中的頭數
            window_size: 視窗自注意力機制的視窗中的patch數
            mlp_ratio: 多層感知機模組中第一個全連線層輸出的通道倍數
            qkv_bias: 如果是True的話,對自注意力公式中的Q、K、V增加一個可學習的偏置
            qk_scale: 自注意力公式中的常量
            drop_rate: dropout rate,預設為0
            attn_drop_rate: 用於自注意力機制中的dropout rate,預設為0
            drop_path_rate: 在Swin Transformer Block中,有一定概率丟棄整個直連分支,包括
                            LN、W-MSA或者SW-MSA,只保留直連的連線,是一種網路深度的隨機性,預設為0.1
            norm_layer: 通道方向歸一化
            ape: 是否進行絕對位置嵌入,預設False
            patch_norm: 如果是True的話,在patch embedding之後加上歸一化
            use_checkpoint: 是否使用Pytorch中間資料儲存機制
            final_upsample: 解碼器stage4後的Patch Expanding
            **kwargs:
        """
        super().__init__()

        print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths,
              depths_decoder, drop_path_rate, num_classes))

        self.num_classes = num_classes
        # number of stages
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        # number of channels of the encoder stage-4 output (768 for Swin-Tiny)
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        # number of channels of the decoder stage-4 output (192)
        self.num_features_up = int(embed_dim * 2)
        self.mlp_ratio = mlp_ratio
        self.final_upsample = final_upsample

        # split the image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        # height and width of the feature map after patch embedding
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # different stages drop the residual branch with different probabilities, increasing from 0 up to drop_path_rate (0.1)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build the encoder layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):  # each layer corresponds to one stage
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               # only the first 3 stages have Patch Merging; the last one does not
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)
        
        # build the decoder layers
        self.layers_up = nn.ModuleList()
        self.concat_back_dim = nn.ModuleList()
        for i_layer in range(self.num_layers):  # each layer corresponds to one stage
            # fully connected layer that halves the channel count of the concatenated skip features
            concat_linear = nn.Linear(2 * int(embed_dim * 2**(self.num_layers - 1 - i_layer)),
                                      int(embed_dim * 2**(self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity()
            if i_layer == 0:  # the first decoder stage only upsamples
                layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
                                       patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer)
            else:
                layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)),
                                         input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),
                                                           patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))),
                                         depth=depths[(self.num_layers-1-i_layer)],
                                         num_heads=num_heads[(self.num_layers-1-i_layer)],
                                         window_size=window_size,
                                         mlp_ratio=self.mlp_ratio,
                                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                                         drop=drop_rate, attn_drop=attn_drop_rate,
                                         drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers - 1 - i_layer) + 1])],
                                         norm_layer=norm_layer,
                                         # only the first 3 decoder stages have Patch Expanding; the last one does not
                                         upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
                                         use_checkpoint=use_checkpoint)
            self.layers_up.append(layer_up)
            self.concat_back_dim.append(concat_linear)

        self.norm = norm_layer(self.num_features)
        self.norm_up = norm_layer(self.embed_dim)
        # after the last decoder stage, apply the final 4x Patch Expanding
        if self.final_upsample == "expand_first":
            print("---final upsample expand_first---")
            self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size), dim_scale=4, dim=embed_dim)
            self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)

        self.apply(self._init_weights)

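Before looking at the final upsampling, it helps to spell out the per-stage resolutions and channel counts that the two loops above produce. This is a small bookkeeping sketch, assuming the default configuration (img_size=224, patch_size=4, embed_dim=96, 4 stages):

embed_dim, num_layers = 96, 4
patches = 224 // 4  # 56x56 tokens after patch embedding

for i_layer in range(num_layers):
    enc_dim = embed_dim * 2 ** i_layer                        # 96, 192, 384, 768
    enc_res = patches // (2 ** i_layer)                       # 56, 28, 14, 7
    dec_dim = embed_dim * 2 ** (num_layers - 1 - i_layer)     # 768, 384, 192, 96
    dec_res = patches // (2 ** (num_layers - 1 - i_layer))    # 7, 14, 28, 56
    print(f"encoder stage {i_layer}: {enc_res}x{enc_res}x{enc_dim} | "
          f"decoder stage {i_layer}: {dec_res}x{dec_res}x{dec_dim}")
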
The constructor above uses a FinalPatchExpand_X4 class; let's look at its implementation.

class FinalPatchExpand_X4(nn.Module):
    """
    stage4之後的PatchExpand
    尺寸翻倍,通道數不變
    """
    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        """
        Args:
            input_resolution: feature map的寬高
            dim: feature map通道數
            dim_scale: 通道數擴充的倍數
            norm_layer: 通道方向歸一化
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.dim_scale = dim_scale
        # expand the channel count 16x with a fully connected layer
        self.expand = nn.Linear(dim, 16 * dim, bias=False)
        self.output_dim = dim 
        self.norm = norm_layer(self.output_dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        # first expand the channel count 16x with the linear layer
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        # split the expanded channels into 4x4 spatial blocks and stitch them back into
        # the feature map, which quadruples its height and width
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))
        # flatten back to tokens with the original channel count
        x = x.view(B, -1, self.output_dim)
        x = self.norm(x)

        return x

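As with PatchExpand, a minimal shape check makes the behaviour concrete (a sketch, assuming the class above is in scope): the 56x56x96 feature from the last decoder stage becomes a 224x224x96 feature, so the spatial size is quadrupled while the channel count stays the same.

import torch

# the output of the last decoder stage: batch 1, 56*56 tokens, 96 channels
x = torch.randn(1, 56 * 56, 96)
up4 = FinalPatchExpand_X4(input_resolution=(56, 56), dim=96)
y = up4(x)
# tokens go from 3136 to 50176 (224x224) while channels stay at 96
print(y.shape)  # torch.Size([1, 50176, 96])
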
Back to the SwinTransformerSys code:

def _init_weights(self, m):
    """
    對全連線層或者通道歸一化進行權重以及偏置的初始化
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)

@torch.jit.ignore
def no_weight_decay(self):
    return {'absolute_pos_embed'}

@torch.jit.ignore
def no_weight_decay_keywords(self):
    return {'relative_position_bias_table'}

#Encoder and Bottleneck
def forward_features(self, x):
    """
    編碼器過程
    """
    # split the image into patches (patch embedding)
    x = self.patch_embed(x)
    # absolute position embedding
    if self.ape:
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)
    # features saved for the skip connections
    x_downsample = []
    # pass through each encoder stage
    for layer in self.layers:
        x_downsample.append(x)
        x = layer(x)

    x = self.norm(x)  # B L C

    return x, x_downsample

# Decoder and skip connections
def forward_up_features(self, x, x_downsample):
    """
    解碼器過程,包含了跳連拼接
    """
    # pass through each decoder stage
    for inx, layer_up in enumerate(self.layers_up):
        if inx == 0:
            x = layer_up(x)
        else:
            # concatenate the corresponding encoder skip feature before entering the Swin Transformer Blocks
            x = torch.cat([x, x_downsample[3-inx]], -1)
            x = self.concat_back_dim[inx](x)
            x = layer_up(x)

    x = self.norm_up(x)  # B L C

    return x

def up_x4(self, x):
    """
    完成解碼器的最後一個stage後進入
    """
    H, W = self.patches_resolution
    B, L, C = x.shape
    assert L == H * W, "input features has wrong size"

    if self.final_upsample == "expand_first":
        x = self.up(x)
        x = x.view(B, 4 * H, 4 * W, -1)
        x = x.permute(0, 3, 1, 2) #B,C,H,W
        x = self.output(x)
        
    return x

def forward(self, x):
    """
    前向運算
    """
    x, x_downsample = self.forward_features(x)
    x = self.forward_up_features(x, x_downsample)
    x = self.up_x4(x)

    return x

def flops(self):
    flops = 0
    flops += self.patch_embed.flops()
    for i, layer in enumerate(self.layers):
        flops += layer.flops()
    flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
    flops += self.num_features * self.num_classes
    return flops

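Putting the pieces together, the whole network maps an input image to per-pixel class logits of the same spatial size. A minimal forward-pass sketch with the default configuration and, say, 3 classes (hypothetical values, just for the shape check, assuming the full model code above is importable):

import torch

model = SwinTransformerSys(img_size=224, patch_size=4, in_chans=3, num_classes=3)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 3, 224, 224]) -> per-pixel logits for each class
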
Next comes model training. Here I dropped the training code of the original repo and used code similar to the earlier U-Net training script.

import torch
from torch import optim
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import copy
from torch.utils.tensorboard import SummaryWriter
from swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys
from utils import DiceLoss

RUN_NAME = 'swinunetv1'
N_CLASSES = 3
INPUT_SIZE = 128
EPOCHS = 21
LEARNING_RATE = 0.01
START_FRAME = 16
DROP_DATE = 0.5

DATA_PATH = '/media/jingzhi/新加捲/'
IMAGE_PATH = 'data_dataset_voc/JPEGImagespng/'
MASK_PATH = 'data_dataset_voc/SegmentationClassPNG-new/'
TEST_IMAGE_PATH = 'test_dataset_voc/JPEGImagespng/'
TEST_MASK_PATH = 'test_dataset_voc/SegmentationClassPNG-new/'
IMAGE_TYPE = '.png'
MASK_TYPE = '.png'
LOG_PATH = './runs'
SAVE_PATH = './'

REAL_HEIGHT = 3000
REAL_WIDTH = 4096
IMG_HEIGHT = 224
IMG_WIDTH = 224
RANDOM_SEED = 42
VALID_RATIO = 0.2
BATCH_SIZE = 32
NUM_WORKERS = 1
CLASSES = {1: 'line'}

class LineDataset(Dataset):

    def __init__(self, root_dir=DATA_PATH, transform=None):
        self.root_dir = root_dir
        listname = []
        for imgfile in os.listdir(DATA_PATH + IMAGE_PATH):
            parts = imgfile.split('.')
            n = len(parts)
            if '.' + parts[n - 1] == IMAGE_TYPE:
                if n > 2:
                    filename = parts[0] + '.' + parts[1]
                else:
                    filename = parts[0]
                listname.append(filename)
        self.ids = listname
        if transform is None:
            self.transform1 = transforms.Compose(
                [transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
                 transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
                 transforms.ToTensor()])
            self.transform2 = transforms.Compose(
                [transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
                 transforms.ToTensor()])
                                                 # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        id = self.ids[index]
        image = Image.open(self.root_dir + IMAGE_PATH + id + IMAGE_TYPE)
        mask = Image.open(self.root_dir + MASK_PATH + id + MASK_TYPE)
        image = self.transform1(image)
        mask = self.transform2(mask)

        return image, mask


def get_trainloader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS):
    train_loader = DataLoader(dataset, batch_size, shuffle=shuffle, num_workers=num_workers)

    return train_loader


def get_dataloader(dataset, batch_size=BATCH_SIZE, random_seed=RANDOM_SEED,
                   valid_ratio=VALID_RATIO, shuffle=True, num_workers=NUM_WORKERS):
    error_msg = "[!] valid_ratio should be in the range [0, 1]."
    assert ((valid_ratio >= 0) and (valid_ratio <= 1)), error_msg

    n = len(dataset)
    n_valid = int(valid_ratio * n)
    n_train = n - n_valid

    torch.manual_seed(random_seed)
    train_dataset, valid_dataset = random_split(dataset, (n_train, n_valid))
    #
    train_loader = DataLoader(train_dataset, batch_size, shuffle=shuffle, num_workers=num_workers)
    valid_loader = DataLoader(valid_dataset, batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, valid_loader


def show_dataset(dataset, n_sample=4):
    plt.figure(figsize=(30, 15))
    for i in range(n_sample):
        image, mask = dataset[i]
        image = transforms.ToPILImage()(image)
        mask = transforms.ToPILImage()(mask)
        print(i, image.size, mask.size)

        plt.tight_layout()
        ax = plt.subplot(n_sample, 1, i + 1)
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')

        plt.imshow(image, cmap="Greys")
        plt.imshow(mask, alpha=0.3, cmap="OrRd")

        if i == n_sample - 1:
            plt.show()
            break


class Test_LineDataset(Dataset):

    def __init__(self, root_dir=DATA_PATH, transform=None):
        self.root_dir = root_dir
        listname = []
        for imgfile in os.listdir(DATA_PATH + TEST_MASK_PATH):
            if '.' + imgfile.split('.')[1] == MASK_TYPE:
                filename = imgfile.split('.')[0]
                listname.append(filename)
        self.ids = listname
        if transform is None:
            self.transform = transforms.Compose([transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
                                                 transforms.ToTensor()])
                                                 # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        id = self.ids[index]
        image = Image.open(self.root_dir + TEST_IMAGE_PATH + id + IMAGE_TYPE)
        mask = Image.open(self.root_dir + TEST_MASK_PATH + id + MASK_TYPE)

        image = self.transform(image)
        mask = self.transform(mask)
        return image, mask


def get_validloader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS):
    valid_loader = DataLoader(dataset, batch_size, shuffle=shuffle, num_workers=num_workers)

    return valid_loader


def show_test_dataset(dataset, n_sample=2):
    plt.figure(figsize=(30, 15))
    for i in range(n_sample):
        image = dataset[i][0]
        image = transforms.ToPILImage()(image)
        print(i, image.size)

        plt.tight_layout()
        ax = plt.subplot(1, n_sample, i + 1)
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')

        plt.imshow(image, cmap="Greys")
        if i == n_sample - 1:
            plt.show()
            break


def labels():
    l = {}
    for i, label in enumerate(CLASSES):
        l[i] = label
    return l


def tensor2np(tensor):
    tensor = tensor.squeeze().cpu()
    return tensor.detach().numpy()


def normtensor(tensor):
    tensor = torch.where(tensor < 0., torch.zeros(1).cuda(), torch.ones(1).cuda())
    return tensor


def count_params(model):
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    return pytorch_total_params


def cal_iou(outputs, labels, SMOOTH=1e-6):
    with torch.no_grad():
        outputs = outputs.squeeze(1).bool()
        labels = labels.squeeze(1).bool()

        intersection = (outputs & labels).float().sum((1, 2))
        union = (outputs | labels).float().sum((1, 2))

        iou = (intersection + SMOOTH) / (union + SMOOTH)
    return iou


def get_iou_score(outputs, labels):
    A = labels.squeeze(1).bool()
    pred = torch.where(outputs < 0., torch.zeros(1).cuda(), torch.ones(1).cuda())
    B = pred.squeeze(1).bool()
    intersection = (A & B).float().sum((1, 2))
    union = (A | B).float().sum((1, 2))
    iou = (intersection + 1e-6) / (union + 1e-6)
    return iou.cpu().detach().numpy()


def train(model, device, trainloader, optimizer, loss_function, dice_function, epoch):
    model.train()
    # model.is_train = True
    running_loss = 0
    mask_list, iou = [], []

    for i, (input, mask) in enumerate(trainloader):
        input, mask = input.to(device), mask.to(device)
        predict = model(input)
        loss_ce = loss_function(predict, mask)
        loss_dice = dice_function(predict, mask, softmax=True)
        loss = 0.4 * loss_ce + 0.6 * loss_dice
        iou.append(get_iou_score(predict, mask).mean())
        running_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if ((i + 1) % 10) == 0:
            pred = normtensor(predict[0])
            img, pred, mak = tensor2np(input[0]), tensor2np(pred), tensor2np(mask[0])
            print(f'Epoch: {epoch} | Item: {i} | Train loss: {loss:.5f}')

    mean_iou = np.mean(iou)
    total_loss = running_loss / len(trainloader)
    writer.add_scalar('training loss', total_loss, epoch)
    return total_loss, mean_iou


def test(model, device, testloader, loss_function, dice_function, best_iou, epoch):
    model.eval()
    # model.is_train = False
    running_loss = 0
    mask_list, iou = [], []
    with torch.no_grad():
        for i, (input, mask) in enumerate(testloader):
            input, mask = input.to(device), mask.to(device)
            predict = model(input)
            loss_ce = loss_function(predict, mask)
            loss_dice = dice_function(predict, mask, softmax=True)
            loss = 0.4 * loss_ce + 0.6 * loss_dice
            running_loss += loss.item()
            iou.append(get_iou_score(predict, mask).mean())

            if ((i + 1) % 1) == 0:
                pred = normtensor(predict[0])
                img, pred, mak = tensor2np(input[0]), tensor2np(pred), tensor2np(mask[0])
                print(f'Epoch: {epoch} | Item: {i} | test loss: {loss:.5f}')

    test_loss = running_loss / len(testloader)
    mean_iou = np.mean(iou)
    writer.add_scalar('val loss', test_loss, epoch)
    if mean_iou > best_iou:
        try:
            torch.save(model.state_dict(), SAVE_PATH + RUN_NAME + '.pth')
        except:
            print('Cannot export weights')
    return test_loss, mean_iou


def model_pipeline(prev_model=None):
    best_model = None
    model, criterion1, criterion2, optimizer = make_model(prev_model)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    best_iou = -1
    for epoch in range(EPOCHS):
        t0 = time.time()
        train_loss, train_iou = train(model, device, trainloader, optimizer, criterion1, criterion2, epoch)
        t1 = time.time()
        print(f'Epoch: {epoch} | Train loss: {train_loss:.5f} | Train IoU: {train_iou:.3f} | Time: '
              f'{(t1 - t0):.1f}s')
        t0 = time.time()
        test_loss, test_iou = test(model, device, validloader, criterion1, criterion2, best_iou, epoch)
        t1 = time.time()
        print(f'Epoch: {epoch} | Valid loss: {test_loss:.5f} | Valid IoU: {test_iou:.3f} | Time: '
              f'{(t1 - t0):.1f}s')
        scheduler.step()
        if best_iou < test_iou:
            best_iou = test_iou
            best_model = copy.deepcopy(model)
    return best_model


def make_model(prev_model=None):
    if prev_model is None:
        model = SwinTransformerSys().to(device)
    else:
        model = prev_model
    print("Number of parameter:", count_params(model))
    criterion1 = nn.BCEWithLogitsLoss()
    criterion2 = DiceLoss(2)
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=0.0001)
    return model, criterion1, criterion2, optimizer

def predict(model, test_loader, device):
    model.eval()
    predicted_masks = []
    back_transform = transforms.Compose([transforms.Resize((REAL_HEIGHT, REAL_WIDTH))])
    with torch.no_grad():
        for i, (input, _) in enumerate(test_loader):
            input = input.to(device)
            predict = model(input)
            predict = back_transform(predict)
            predict = (predict > 0).type(torch.float)
            predicted_masks.append(predict)
    predicted_masks = torch.cat(predicted_masks)
    return predicted_masks


def show_sample_test_result(test_dataset, predicted_mask, n_samples=60):
    plt.rcParams['figure.figsize'] = (30, 15)
    back_transform = transforms.Compose([transforms.Resize((REAL_HEIGHT, REAL_WIDTH))])
    for i in range(n_samples):
        sample = predicted_mask[i]
        sample = torch.squeeze(sample, dim=0)
        sample = transforms.ToPILImage()(sample)
        X = test_dataset[i][0]
        X = back_transform(X)
        X = transforms.ToPILImage()(X)

        if (i + 1) % 4 != 0:
            index = (i + 1) % 4
        else:
            index = 4
        ax = plt.subplot(2, 2, index)
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')
        plt.imshow(X, cmap="Greys")
        plt.imshow(sample, alpha=0.7, cmap="winter")
        # if i == n_samples - 1:
        if i % 3 == 0 and i != 0:
            plt.show()
            # break


if __name__ == '__main__':

    writer = SummaryWriter(LOG_PATH)
    dataset = LineDataset(DATA_PATH)
    valid_dataset = Test_LineDataset(DATA_PATH)
    trainloader = get_trainloader(dataset=dataset)
    validloader = get_validloader(dataset=valid_dataset)
    # show_dataset(dataset)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model = UNet_ResNet()
    # model = model.to(device)
    # model.load_state_dict(torch.load(SAVE_PATH + RUN_NAME + '.pth'))

    # print(device)
    model = model_pipeline()
    writer.close()
    # predict_mask = predict(model, validloader, device)
    # show_sample_test_result(valid_dataset, predict_mask)

The code for DiceLoss is as follows:

import numpy as np
import torch
from medpy import metric
from scipy.ndimage import zoom
import torch.nn as nn
import SimpleITK as sitk


class DiceLoss(nn.Module):
    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i  # * torch.ones_like(input_tensor)
            print(input_tensor.size())
            print(temp_prob.size())
            tensor_list.append(temp_prob.unsqueeze(1))
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss

    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        # target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        class_wise_dice = []
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            class_wise_dice.append(1.0 - dice.item())
            loss += dice * weight[i]
        return loss / self.n_classes


def calculate_metric_percase(pred, gt):
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0 and gt.sum()>0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    elif pred.sum() > 0 and gt.sum()==0:
        return 1, 0
    else:
        return 0, 0


def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1):
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    if len(image.shape) == 3:
        prediction = np.zeros_like(label)
        for ind in range(image.shape[0]):
            slice = image[ind, :, :]
            x, y = slice.shape[0], slice.shape[1]
            if x != patch_size[0] or y != patch_size[1]:
                slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3)  # previous using 0
            input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
            net.eval()
            with torch.no_grad():
                outputs = net(input)
                out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
                out = out.cpu().detach().numpy()
                if x != patch_size[0] or y != patch_size[1]:
                    pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
                else:
                    pred = out
                prediction[ind] = pred
    else:
        input = torch.from_numpy(image).unsqueeze(
            0).unsqueeze(0).float().cuda()
        net.eval()
        with torch.no_grad():
            out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)
            prediction = out.cpu().detach().numpy()
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))

    if test_save_path is not None:
        img_itk = sitk.GetImageFromArray(image.astype(np.float32))
        prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32))
        lab_itk = sitk.GetImageFromArray(label.astype(np.float32))
        img_itk.SetSpacing((1, 1, z_spacing))
        prd_itk.SetSpacing((1, 1, z_spacing))
        lab_itk.SetSpacing((1, 1, z_spacing))
        sitk.WriteImage(prd_itk, test_save_path + '/' + case + "_pred.nii.gz")
        sitk.WriteImage(img_itk, test_save_path + '/' + case + "_img.nii.gz")
        sitk.WriteImage(lab_itk, test_save_path + '/' + case + "_gt.nii.gz")
    return metric_list
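
For reference, here is a minimal sketch of how this DiceLoss is combined with the other loss term, mirroring the 0.4/0.6 weighting used in the training loop above. The shapes are hypothetical; the only requirement DiceLoss enforces is that prediction and target share the same (B, n_classes, H, W) shape, with softmax=True normalising the logits first:

import torch

dice = DiceLoss(2)
bce = torch.nn.BCEWithLogitsLoss()

logits = torch.randn(4, 2, 224, 224)        # raw network outputs
target = torch.zeros(4, 2, 224, 224)        # toy one-hot target: everything belongs to class 0
target[:, 0] = 1.0

loss = 0.4 * bce(logits, target) + 0.6 * dice(logits, target, softmax=True)
print(loss.item())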