Swin-Unet分割網路
Swin-Unet以Swin Transformer為基礎(可參考Swin Transformer介紹),結合了U-Net網路的特點(可參考Tensorflow深度學習演算法整理(三)中的U-Net),組合而成的新分割網路。
它與Swin Transformer不同的地方在於:在編碼器(Encoder)這邊,雖然與Swin Transformer一樣有4個Stage,但各Stage中Swin Transformer Block的數量為[2,2,2,2],而不是Swin Transformer(Swin-T)的[2,2,6,2]。而在解碼器(Decoder)這邊,由於要進行上取樣,使用的不再是Patch Embedding和Patch Merging,而是Patch Expanding,它是Patch Merging的逆過程。
我們來看一下Patch Expanding的程式碼實現
from einops import rearrange
class PatchExpand(nn.Module):
    """Patch Expanding layer: the inverse of Patch Merging.

    A bias-free linear layer widens every token from ``dim`` to ``2 * dim``
    channels.  Those channels are then spread over a 2x2 spatial block, so
    each input token becomes four output tokens with ``dim // 2`` channels:
    the spatial size doubles per side and the channel count halves.

    Args:
        input_resolution: (H, W) of the incoming feature map.
        dim: number of input channels.
        dim_scale: only the value 2 selects the widening linear layer; any
            other value uses ``nn.Identity`` for ``expand``.  The spatial
            factor used in ``forward`` is always 2 regardless of this value,
            so only ``dim_scale=2`` is meaningfully supported.
        norm_layer: normalisation class applied over the output channels
            (constructed with ``dim // dim_scale`` features).
    """

    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        if dim_scale == 2:
            # Widen the channels so they can be split across a 2x2 block.
            self.expand = nn.Linear(dim, 2 * dim, bias=False)
        else:
            self.expand = nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """Map ``x`` of shape (B, H*W, C) to (B, 4*H*W, C_expanded // 4)."""
        height, width = self.input_resolution
        x = self.expand(x)
        batch, length, channels = x.shape
        assert length == height * width, "input feature has wrong size"
        out_channels = channels // 4
        # (B, H, W, p1*p2*c) -> (B, H, p1, W, p2, c) -> (B, (2H)*(2W), c).
        # This is the einops pattern 'b h w (p1 p2 c) -> b (h p1) (w p2) c'
        # with p1 = p2 = 2, followed by flattening the spatial grid.
        x = x.view(batch, height, width, 2, 2, out_channels)
        x = x.permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(batch, -1, out_channels)
        return self.norm(x)
在編碼器這邊基本上與Swin Transformer相同,我們重點來看解碼器這邊。解碼器使用BasicLayer_up類別將SwinTransformerBlock與Patch Expanding組合在一起。
class BasicLayer_up(nn.Module):
    """One decoder stage of Swin-Unet.

    Applies ``depth`` SwinTransformerBlocks at a fixed resolution (even-indexed
    blocks use W-MSA, odd-indexed ones SW-MSA), then an optional upsampling
    layer (Patch Expanding in Swin-Unet).
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
        """
        Args:
            dim: number of feature-map channels.
            input_resolution: (H, W) of the feature map.
            depth: number of SwinTransformerBlocks in this stage.
            num_heads: number of attention heads.
            window_size: window side length (in patches) for window attention.
            mlp_ratio: hidden-width multiplier of each block's MLP.
            qkv_bias: if True, add a learnable bias to Q, K and V.
            qk_scale: override for the attention scaling constant.
            drop: dropout rate (default 0).
            attn_drop: dropout rate inside attention (default 0).
            drop_path: stochastic-depth rate; either one value shared by all
                blocks or a list holding one rate per block.
            norm_layer: normalisation layer class.
            upsample: upsampling layer class, instantiated as
                ``upsample(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)``;
                ``None`` disables upsampling.  A non-callable non-None value
                (e.g. ``True``) falls back to ``PatchExpand``.
            use_checkpoint: if True, each block is run through
                ``checkpoint.checkpoint`` (presumably ``torch.utils.checkpoint``,
                imported at module level elsewhere — confirm).
        """
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Build the SwinTransformerBlocks of this stage.
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 # shift 0 -> W-MSA (even i), window_size // 2 -> SW-MSA (odd i)
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # Upsampling layer; None for the last decoder stage.
        if upsample is not None:
            # Honour the class the caller passed instead of always building
            # PatchExpand; keep the old behaviour for flag-like values.
            upsample_cls = upsample if callable(upsample) else PatchExpand
            self.upsample = upsample_cls(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
        else:
            self.upsample = None

    def forward(self, x):
        """Run all blocks of the stage on ``x``, then upsample if configured."""
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        # Patch Expanding upsampling.
        if self.upsample is not None:
            x = self.upsample(x)
        return x
SwinTransformerBlock跟SwinTransformer中的程式碼也是一樣的,這裡就不重複了。
此外,編碼器與解碼器之間還有跳連(skip connection),這部分需要看一下Swin-Unet的主類別程式碼:
class SwinTransformerSys(nn.Module):
    """
    Swin-Unet segmentation network.

    Consists of a Swin Transformer encoder (``self.layers``), a decoder built
    from PatchExpand / BasicLayer_up stages (``self.layers_up``) joined to the
    encoder by skip connections (fused with ``self.concat_back_dim``), and a
    final FinalPatchExpand_X4 plus 1x1 convolution (``self.up`` /
    ``self.output``) producing ``num_classes`` output channels.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, final_upsample="expand_first", **kwargs):
        """
        Args:
            img_size: side length of the input image; used as a scalar in
                ``img_size // patch_size`` below, so square inputs are assumed.
            patch_size: side length of one patch, passed to PatchEmbed.
            in_chans: number of input image channels.
            num_classes: out_channels of the final 1x1 convolution.
            embed_dim: channel count after patch embedding; encoder stage i has
                ``embed_dim * 2 ** i`` channels.
            depths: number of Swin Transformer Blocks per encoder stage.  The
                decoder also reads its per-stage depth from this list, in
                reverse order.
            depths_decoder: NOTE(review): only printed below; it has no effect
                on the architecture built here.
            num_heads: attention heads per encoder stage (indexed in reverse by
                the decoder).
            window_size: window side length (in patches) for window attention.
            mlp_ratio: hidden-width multiplier of the MLP in each block.
            qkv_bias: if True, add a learnable bias to Q, K and V.
            qk_scale: override for the attention scaling constant.
            drop_rate: dropout rate (default 0).
            attn_drop_rate: dropout rate inside attention (default 0).
            drop_path_rate: largest stochastic-depth rate; per-block rates grow
                linearly from 0 to this value (default 0.1).
            norm_layer: normalisation layer class.
            ape: if True, create a learnable absolute position embedding.
            patch_norm: if True, PatchEmbed gets ``norm_layer``; otherwise None.
            use_checkpoint: forwarded to every encoder/decoder stage.
            final_upsample: only "expand_first" builds ``self.up`` and
                ``self.output``.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths,
        depths_decoder, drop_path_rate, num_classes))

        self.num_classes = num_classes
        # Number of stages.
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        # Channel count of the last encoder stage (768 for embed_dim=96 with 4 stages).
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        # embed_dim * 2 (192 for embed_dim=96).  NOTE(review): not read anywhere in the visible code.
        self.num_features_up = int(embed_dim * 2)
        self.mlp_ratio = mlp_ratio
        self.final_upsample = final_upsample

        # Split the image into non-overlapping patches.
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        # (height, width) of the patch grid.
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # Absolute position embedding (learnable, truncated-normal init).
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # One stochastic-depth rate per encoder block, rising linearly from 0 to drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # Encoder stages.
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               # This stage's slice of the stochastic-depth rates.
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               # Every stage except the last ends with PatchMerging.
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        # Decoder stages.
        self.layers_up = nn.ModuleList()
        self.concat_back_dim = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # Maps the concatenated [decoder, skip] features (2 * C channels)
            # back to C channels; identity for the first decoder stage.
            concat_linear = nn.Linear(2 * int(embed_dim * 2**(self.num_layers - 1 - i_layer)),
                                      int(embed_dim * 2**(self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity()
            if i_layer == 0:
                # The first decoder stage only upsamples.
                layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
                                                         patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))),
                                       dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer)
            else:
                layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)),
                                         input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),
                                                           patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))),
                                         # NOTE(review): reads the encoder `depths`, not `depths_decoder`.
                                         depth=depths[(self.num_layers-1-i_layer)],
                                         num_heads=num_heads[(self.num_layers-1-i_layer)],
                                         window_size=window_size,
                                         mlp_ratio=self.mlp_ratio,
                                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                                         drop=drop_rate, attn_drop=attn_drop_rate,
                                         # Reuses the mirrored encoder stage's slice of stochastic-depth rates.
                                         drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])],
                                         norm_layer=norm_layer,
                                         # Every decoder stage except the last ends with PatchExpand.
                                         upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
                                         use_checkpoint=use_checkpoint)
            self.layers_up.append(layer_up)
            self.concat_back_dim.append(concat_linear)

        # Norm over the last encoder stage's channels, and over embed_dim channels for the decoder output.
        self.norm = norm_layer(self.num_features)
        self.norm_up = norm_layer(self.embed_dim)

        # After the last decoder stage: FinalPatchExpand_X4 (dim_scale=4) then a 1x1 conv to num_classes channels.
        if self.final_upsample == "expand_first":
            print("---final upsample expand_first---")
            self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size), dim_scale=4, dim=embed_dim)
            self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)

        # Initialise weights with the class's _init_weights method.
        self.apply(self._init_weights)
這裡用到了FinalPatchExpand_X4類別,我們來看一下它的實現:
class FinalPatchExpand_X4(nn.Module):
    """Final patch expansion applied after the last decoder stage.

    Upsamples the token grid by ``dim_scale`` per side (4 by default) while
    keeping ``dim`` channels.  A bias-free linear layer widens each token to
    ``dim_scale ** 2 * dim`` channels, which are then redistributed over a
    ``dim_scale x dim_scale`` spatial block.

    Args:
        input_resolution: (H, W) of the incoming token grid.
        dim: number of channels, for both input and output.
        dim_scale: spatial upsampling factor per side (default 4).
        norm_layer: normalisation class applied over the output channels.
    """

    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.dim_scale = dim_scale
        # Widen by dim_scale**2 (16 for the default 4) rather than a fixed 16.
        # After the channels are spread over a dim_scale x dim_scale block,
        # exactly `dim` channels then remain per output token for any dim_scale.
        self.expand = nn.Linear(dim, dim_scale ** 2 * dim, bias=False)
        self.output_dim = dim
        self.norm = norm_layer(self.output_dim)

    def forward(self, x):
        """Map ``x`` of shape (B, H*W, dim) to (B, dim_scale**2 * H*W, dim)."""
        H, W = self.input_resolution
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        s = self.dim_scale
        c = C // (s ** 2)
        # (B, H, W, p1*p2*c) -> (B, H, p1, W, p2, c) -> (B, (sH)*(sW), c).
        # This is the einops pattern 'b h w (p1 p2 c) -> b (h p1) (w p2) c'
        # with p1 = p2 = dim_scale, followed by flattening the spatial grid.
        x = x.view(B, H, W, s, s, c)
        x = x.permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(B, -1, self.output_dim)
        x = self.norm(x)
        return x
回到SwinTransformerSys程式碼中
def _init_weights(self, m):
    """Initialise one submodule; the class applies it via ``self.apply(self._init_weights)``.

    ``nn.Linear``: weight drawn with ``trunc_normal_`` (std=0.02), bias (if any) set to 0.
    ``nn.LayerNorm``: bias set to 0, weight set to 1.
    Other module types are left untouched.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)


@torch.jit.ignore
def no_weight_decay(self):
    """Parameter names to exempt from weight decay (consumer not shown here)."""
    return {'absolute_pos_embed'}


@torch.jit.ignore
def no_weight_decay_keywords(self):
    """Parameter-name keywords to exempt from weight decay (consumer not shown here)."""
    return {'relative_position_bias_table'}


# Encoder and bottleneck
def forward_features(self, x):
    """Run the encoder.

    Returns:
        (x, x_downsample): the normalised bottleneck tokens (B, L, C) and the
        list holding the input of every encoder stage, used as skip features.
    """
    x = self.patch_embed(x)
    if self.ape:
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)
    # Record each stage's input *before* the stage runs.
    x_downsample = []
    for layer in self.layers:
        x_downsample.append(x)
        x = layer(x)
    x = self.norm(x)  # B L C
    return x, x_downsample


# Decoder and skip connections
def forward_up_features(self, x, x_downsample):
    """Run the decoder, fusing an encoder skip feature at every stage after the first.

    Decoder stage ``inx >= 1`` concatenates ``x`` with
    ``x_downsample[self.num_layers - 1 - inx]`` along the channel axis,
    projects it back with ``self.concat_back_dim[inx]`` and feeds the result
    to ``self.layers_up[inx]``.  Stage 0 only upsamples.
    """
    for inx, layer_up in enumerate(self.layers_up):
        if inx == 0:
            x = layer_up(x)
        else:
            # Derive the skip index from num_layers instead of a fixed 3 so
            # that models with a different stage count pair matching resolutions.
            x = torch.cat([x, x_downsample[self.num_layers - 1 - inx]], -1)
            x = self.concat_back_dim[inx](x)
            x = layer_up(x)
    x = self.norm_up(x)  # B L C
    return x


def up_x4(self, x):
    """Final upsampling and per-pixel classification after the last decoder stage.

    With ``final_upsample == "expand_first"``, tokens of shape (B, H*W, C) on
    the patch grid become a (B, num_classes, 4H, 4W) map via ``self.up`` and
    ``self.output``; otherwise ``x`` is returned unchanged.
    """
    H, W = self.patches_resolution
    B, L, C = x.shape
    assert L == H * W, "input features has wrong size"
    if self.final_upsample == "expand_first":
        x = self.up(x)
        x = x.view(B, 4 * H, 4 * W, -1)
        x = x.permute(0, 3, 1, 2)  # B,C,H,W
        x = self.output(x)
    return x


def forward(self, x):
    """Full forward pass: encoder -> decoder with skips -> final upsampling."""
    x, x_downsample = self.forward_features(x)
    x = self.forward_up_features(x, x_downsample)
    x = self.up_x4(x)
    return x


def flops(self):
    """FLOP estimate covering patch embedding, the encoder stages and the final terms; decoder stages are not counted."""
    flops = 0
    flops += self.patch_embed.flops()
    for i, layer in enumerate(self.layers):
        flops += layer.flops()
    flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
    flops += self.num_features * self.num_classes
    return flops
接下來就是模型訓練了。這裡我捨棄了原框架的訓練程式碼,改用與之前U-Net類似的訓練程式碼:
import torch
from torch import optim
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import copy
from torch.utils.tensorboard import SummaryWriter
from swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys
from utils import DiceLoss

RUN_NAME = 'swinunetv1'
N_CLASSES = 3
INPUT_SIZE = 128
EPOCHS = 21
LEARNING_RATE = 0.01
START_FRAME = 16
DROP_DATE = 0.5
DATA_PATH = '/media/jingzhi/新加捲/'
IMAGE_PATH = 'data_dataset_voc/JPEGImagespng/'
MASK_PATH = 'data_dataset_voc/SegmentationClassPNG-new/'
TEST_IMAGE_PATH = 'test_dataset_voc/JPEGImagespng/'
TEST_MASK_PATH = 'test_dataset_voc/SegmentationClassPNG-new/'
IMAGE_TYPE = '.png'
MASK_TYPE = '.png'
LOG_PATH = './runs'
SAVE_PATH = './'
REAL_HEIGHT = 3000
REAL_WIDTH = 4096
IMG_HEIGHT = 224
IMG_WIDTH = 224
RANDOM_SEED = 42
VALID_RATIO = 0.2
BATCH_SIZE = 32
NUM_WORKERS = 1
CLASSES = {1: 'line'}


def _list_stems(directory, extension):
    """Return names of files in *directory* whose extension equals *extension*, with it removed.

    ``os.path.splitext`` strips only the final suffix, so ``'a.b.c.png'``
    yields ``'a.b.c'``; names without that exact extension are skipped.
    """
    stems = []
    for file_name in os.listdir(directory):
        stem, ext = os.path.splitext(file_name)
        if ext == extension:
            stems.append(stem)
    return stems


class LineDataset(Dataset):
    """Training set of (image, mask) pairs that share a file stem.

    Images are read from ``root_dir + IMAGE_PATH``, masks from ``root_dir + MASK_PATH``.
    """

    def __init__(self, root_dir=DATA_PATH, transform=None):
        self.root_dir = root_dir
        # List the same directory that __getitem__ reads from.
        self.ids = _list_stems(self.root_dir + IMAGE_PATH, IMAGE_TYPE)
        if transform is None:
            # Images: resize, colour-jitter augmentation, tensor conversion.
            self.transform1 = transforms.Compose(
                [transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
                 transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
                 transforms.ToTensor()])
            # Masks: nearest-neighbour resize (no jitter), tensor conversion.
            self.transform2 = transforms.Compose(
                [transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
                 transforms.ToTensor()])
            # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
        else:
            # A caller-supplied transform is applied to both image and mask
            # (otherwise transform1/transform2 would be undefined).
            self.transform1 = transform
            self.transform2 = transform

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        sample_id = self.ids[index]
        image = Image.open(self.root_dir + IMAGE_PATH + sample_id + IMAGE_TYPE)
        mask = Image.open(self.root_dir + MASK_PATH + sample_id + MASK_TYPE)
        return self.transform1(image), self.transform2(mask)


def get_trainloader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS):
    """Wrap *dataset* in a shuffling DataLoader."""
    return DataLoader(dataset, batch_size, shuffle=shuffle, num_workers=num_workers)


def get_dataloader(dataset, batch_size=BATCH_SIZE, random_seed=RANDOM_SEED, valid_ratio=VALID_RATIO, shuffle=True,
                   num_workers=NUM_WORKERS):
    """Split *dataset* into train/validation subsets (seeded) and return a DataLoader for each.

    Raises:
        AssertionError: if valid_ratio is outside [0, 1].
    """
    error_msg = "[!] valid_ratio should be in the range [0, 1]."
    assert ((valid_ratio >= 0) and (valid_ratio <= 1)), error_msg
    n = len(dataset)
    n_valid = int(valid_ratio * n)
    n_train = n - n_valid
    torch.manual_seed(random_seed)
    train_dataset, valid_dataset = random_split(dataset, (n_train, n_valid))
    train_loader = DataLoader(train_dataset, batch_size, shuffle=shuffle, num_workers=num_workers)
    valid_loader = DataLoader(valid_dataset, batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, valid_loader


def show_dataset(dataset, n_sample=4):
    """Plot the first *n_sample* images with their masks overlaid."""
    plt.figure(figsize=(30, 15))
    for i in range(n_sample):
        image, mask = dataset[i]
        image = transforms.ToPILImage()(image)
        mask = transforms.ToPILImage()(mask)
        print(i, image.size, mask.size)
        plt.tight_layout()
        ax = plt.subplot(n_sample, 1, i + 1)
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')
        plt.imshow(image, cmap="Greys")
        plt.imshow(mask, alpha=0.3, cmap="OrRd")
        if i == n_sample - 1:
            plt.show()
            break


class Test_LineDataset(Dataset):
    """Evaluation set: stems listed from ``root_dir + TEST_MASK_PATH``, images read from ``root_dir + TEST_IMAGE_PATH``."""

    def __init__(self, root_dir=DATA_PATH, transform=None):
        self.root_dir = root_dir
        self.ids = _list_stems(self.root_dir + TEST_MASK_PATH, MASK_TYPE)
        if transform is None:
            self.transform = transforms.Compose(
                [transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
                 transforms.ToTensor()])
            # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        sample_id = self.ids[index]
        image = Image.open(self.root_dir + TEST_IMAGE_PATH + sample_id + IMAGE_TYPE)
        mask = Image.open(self.root_dir + TEST_MASK_PATH + sample_id + MASK_TYPE)
        return self.transform(image), self.transform(mask)


def get_validloader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS):
    """Wrap *dataset* in a non-shuffling DataLoader."""
    return DataLoader(dataset, batch_size, shuffle=shuffle, num_workers=num_workers)


def show_test_dataset(dataset, n_sample=2):
    """Plot the first *n_sample* test images side by side."""
    plt.figure(figsize=(30, 15))
    for i in range(n_sample):
        image = dataset[i][0]
        image = transforms.ToPILImage()(image)
        print(i, image.size)
        plt.tight_layout()
        ax = plt.subplot(1, n_sample, i + 1)
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')
        plt.imshow(image, cmap="Greys")
        if i == n_sample - 1:
            plt.show()
            break


def labels():
    """Map position -> key of CLASSES (NOTE(review): yields keys such as 1, not names such as 'line')."""
    return {i: label for i, label in enumerate(CLASSES)}


def tensor2np(tensor):
    """Squeeze *tensor*, move it to the CPU and return it as a NumPy array."""
    tensor = tensor.squeeze().cpu()
    return tensor.detach().numpy()


def normtensor(tensor):
    """Binarise logits: 0 where tensor < 0, else 1, on the tensor's own device and dtype."""
    return torch.where(tensor < 0., torch.zeros_like(tensor), torch.ones_like(tensor))


def count_params(model):
    """Total number of parameters in *model*."""
    return sum(p.numel() for p in model.parameters())


def cal_iou(outputs, labels, SMOOTH=1e-6):
    """Per-sample IoU of two binary (B, 1, H, W) tensors."""
    with torch.no_grad():
        outputs = outputs.squeeze(1).bool()
        labels = labels.squeeze(1).bool()
        intersection = (outputs & labels).float().sum((1, 2))
        union = (outputs | labels).float().sum((1, 2))
        iou = (intersection + SMOOTH) / (union + SMOOTH)
    return iou


def get_iou_score(outputs, labels):
    """Per-sample IoU (NumPy array) between thresholded logits (>= 0) and *labels*."""
    A = labels.squeeze(1).bool()
    pred = torch.where(outputs < 0., torch.zeros_like(outputs), torch.ones_like(outputs))
    B = pred.squeeze(1).bool()
    intersection = (A & B).float().sum((1, 2))
    union = (A | B).float().sum((1, 2))
    iou = (intersection + 1e-6) / (union + 1e-6)
    return iou.cpu().detach().numpy()


def train(model, device, trainloader, optimizer, loss_function, dice_function, epoch):
    """Run one training epoch and log its mean loss to the global ``writer``.

    Returns:
        (mean loss over batches, mean IoU over batches)
    """
    model.train()
    running_loss = 0
    iou = []
    for i, (inputs, mask) in enumerate(trainloader):
        inputs, mask = inputs.to(device), mask.to(device)
        logits = model(inputs)
        loss_ce = loss_function(logits, mask)
        loss_dice = dice_function(logits, mask, softmax=True)
        # Weighted combination: 0.4 * loss_function + 0.6 * dice_function.
        loss = 0.4 * loss_ce + 0.6 * loss_dice
        iou.append(get_iou_score(logits, mask).mean())
        running_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % 10 == 0:
            print(f'Epoch: {epoch} | Item: {i} | Train loss: {loss:.5f}')
    mean_iou = np.mean(iou)
    total_loss = running_loss / len(trainloader)
    writer.add_scalar('training loss', total_loss, epoch)
    return total_loss, mean_iou


def test(model, device, testloader, loss_function, dice_function, best_iou, epoch):
    """Evaluate one epoch, log the mean loss, and save weights if IoU beats *best_iou*.

    Returns:
        (mean loss over batches, mean IoU over batches)
    """
    model.eval()
    running_loss = 0
    iou = []
    with torch.no_grad():
        for i, (inputs, mask) in enumerate(testloader):
            inputs, mask = inputs.to(device), mask.to(device)
            logits = model(inputs)
            loss_ce = loss_function(logits, mask)
            loss_dice = dice_function(logits, mask, softmax=True)
            loss = 0.4 * loss_ce + 0.6 * loss_dice
            running_loss += loss.item()
            iou.append(get_iou_score(logits, mask).mean())
            print(f'Epoch: {epoch} | Item: {i} | test loss: {loss:.5f}')
    test_loss = running_loss / len(testloader)
    mean_iou = np.mean(iou)
    writer.add_scalar('val loss', test_loss, epoch)
    if mean_iou > best_iou:
        try:
            torch.save(model.state_dict(), SAVE_PATH + RUN_NAME + '.pth')
        except (OSError, RuntimeError) as err:
            # Best-effort checkpoint: report the failure instead of aborting training.
            print(f"Can't export weights: {err}")
    return test_loss, mean_iou


def model_pipeline(prev_model=None):
    """Train for EPOCHS epochs with a StepLR schedule; return a copy of the model with the best validation IoU."""
    best_model = None
    model, criterion1, criterion2, optimizer = make_model(prev_model)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    best_iou = -1
    for epoch in range(EPOCHS):
        t0 = time.time()
        train_loss, train_iou = train(model, device, trainloader, optimizer, criterion1, criterion2, epoch)
        t1 = time.time()
        print(f'Epoch: {epoch} | Train loss: {train_loss:.5f} | Train IoU: {train_iou:.3f} | Time: '
              f'{(t1 - t0):.1f}s')
        t0 = time.time()
        test_loss, test_iou = test(model, device, validloader, criterion1, criterion2, best_iou, epoch)
        t1 = time.time()
        print(f'Epoch: {epoch} | Valid loss: {test_loss:.5f} | Valid IoU: {test_iou:.3f} | Time: '
              f'{(t1 - t0):.1f}s')
        scheduler.step()
        if best_iou < test_iou:
            best_iou = test_iou
            best_model = copy.deepcopy(model)
    return best_model


def make_model(prev_model=None):
    """Build (or reuse) the model plus losses and an SGD optimizer."""
    if prev_model is None:
        # NOTE(review): built with SwinTransformerSys defaults (num_classes=1000,
        # img_size=224); N_CLASSES is not used. Confirm the output channels match
        # the masks expected by BCEWithLogitsLoss / DiceLoss(2).
        model = SwinTransformerSys().to(device)
    else:
        model = prev_model
    print("Number of parameter:", count_params(model))
    criterion1 = nn.BCEWithLogitsLoss()
    criterion2 = DiceLoss(2)
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=0.0001)
    return model, criterion1, criterion2, optimizer


def predict(model, test_loader, device):
    """Predict binary masks (logit > 0) resized to REAL_HEIGHT x REAL_WIDTH for every batch."""
    model.eval()
    predicted_masks = []
    back_transform = transforms.Compose([transforms.Resize((REAL_HEIGHT, REAL_WIDTH))])
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)
            logits = model(inputs)
            logits = back_transform(logits)
            predicted_masks.append((logits > 0).type(torch.float))
    return torch.cat(predicted_masks)


def show_sample_test_result(test_dataset, predicted_mask, n_samples=60):
    """Overlay predicted masks on test images, four per 2x2 figure."""
    plt.rcParams['figure.figsize'] = (30, 15)
    back_transform = transforms.Compose([transforms.Resize((REAL_HEIGHT, REAL_WIDTH))])
    for i in range(n_samples):
        sample = predicted_mask[i]
        sample = torch.squeeze(sample, dim=0)
        sample = transforms.ToPILImage()(sample)
        X = test_dataset[i][0]
        X = back_transform(X)
        X = transforms.ToPILImage()(X)
        # Position 1..4 in the 2x2 grid.
        index = i % 4 + 1
        ax = plt.subplot(2, 2, index)
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')
        plt.imshow(X, cmap="Greys")
        plt.imshow(sample, alpha=0.7, cmap="winter")
        # Flush once the grid is full (or on the last sample) so groups don't mix.
        if index == 4 or i == n_samples - 1:
            plt.show()


if __name__ == '__main__':
    writer = SummaryWriter(LOG_PATH)
    dataset = LineDataset(DATA_PATH)
    valid_dataset = Test_LineDataset(DATA_PATH)
    trainloader = get_trainloader(dataset=dataset)
    validloader = get_validloader(dataset=valid_dataset)
    # show_dataset(dataset)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model.load_state_dict(torch.load(SAVE_PATH + RUN_NAME + '.pth'))
    model = model_pipeline()
    writer.close()
    # predict_mask = predict(model, validloader, device)
    # show_sample_test_result(valid_dataset, predict_mask)
DiceLoss的程式碼如下
import numpy as np
import torch
from medpy import metric
from scipy.ndimage import zoom
import torch.nn as nn
import SimpleITK as sitk


class DiceLoss(nn.Module):
    """Soft Dice loss averaged over ``n_classes`` channels."""

    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        """Turn an integer label map into a float one-hot tensor with classes stacked on dim 1."""
        tensor_list = [(input_tensor == i).unsqueeze(1) for i in range(self.n_classes)]
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        """1 - soft Dice coefficient between *score* and *target* (smoothed by 1e-5)."""
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss

    def forward(self, inputs, target, weight=None, softmax=False):
        """Weighted mean of per-channel Dice losses.

        *target* must already have the same shape as *inputs* (one-hot
        encoding is not applied here).  With ``softmax=True`` a softmax over
        dim 1 is applied to *inputs* first.
        """
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        # target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            loss += dice * weight[i]
        return loss / self.n_classes


def calculate_metric_percase(pred, gt):
    """Dice and HD95 (medpy) for one binary class.

    Both arrays are binarised in place (values > 0 become 1).

    Returns:
        (dice, hd95) when both contain positives; (1, 0) when pred has
        positives but gt is empty; (0, 0) when pred is empty.
    """
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0 and gt.sum() > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    elif pred.sum() > 0 and gt.sum() == 0:
        return 1, 0
    else:
        return 0, 0


def _infer_device(net):
    """Device of *net*'s first parameter; CUDA if *net* has no parameters."""
    try:
        return next(net.parameters()).device
    except StopIteration:
        return torch.device('cuda')


def test_single_volume(image, label, net, classes, patch_size=(256, 256), test_save_path=None, case=None, z_spacing=1):
    """Predict a 2-D image or a 3-D volume slice by slice and compute per-class metrics.

    Slices whose size differs from *patch_size* are zoomed to it (cubic)
    before inference and the prediction is zoomed back (nearest).  Inputs are
    moved to the device of *net*'s parameters.  If *test_save_path* is given,
    image, prediction and label are written as NIfTI files named after *case*.

    Returns:
        list with one (dice, hd95) tuple per class 1 .. classes-1.
    """
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    device = _infer_device(net)
    net.eval()
    if len(image.shape) == 3:
        prediction = np.zeros_like(label)
        for ind in range(image.shape[0]):
            image_slice = image[ind, :, :]
            x, y = image_slice.shape[0], image_slice.shape[1]
            if x != patch_size[0] or y != patch_size[1]:
                image_slice = zoom(image_slice, (patch_size[0] / x, patch_size[1] / y), order=3)  # previous using 0
            net_input = torch.from_numpy(image_slice).unsqueeze(0).unsqueeze(0).float().to(device)
            with torch.no_grad():
                outputs = net(net_input)
                out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
                out = out.cpu().detach().numpy()
                if x != patch_size[0] or y != patch_size[1]:
                    pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
                else:
                    pred = out
                prediction[ind] = pred
    else:
        net_input = torch.from_numpy(image).unsqueeze(0).unsqueeze(0).float().to(device)
        with torch.no_grad():
            out = torch.argmax(torch.softmax(net(net_input), dim=1), dim=1).squeeze(0)
            prediction = out.cpu().detach().numpy()
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))

    if test_save_path is not None:
        img_itk = sitk.GetImageFromArray(image.astype(np.float32))
        prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32))
        lab_itk = sitk.GetImageFromArray(label.astype(np.float32))
        img_itk.SetSpacing((1, 1, z_spacing))
        prd_itk.SetSpacing((1, 1, z_spacing))
        lab_itk.SetSpacing((1, 1, z_spacing))
        sitk.WriteImage(prd_itk, test_save_path + '/' + case + "_pred.nii.gz")
        sitk.WriteImage(img_itk, test_save_path + '/' + case + "_img.nii.gz")
        sitk.WriteImage(lab_itk, test_save_path + '/' + case + "_gt.nii.gz")
    return metric_list
「其他文章」