PyTorch資料集處理

語言: CN / TW / HK

theme: nico highlight: agate


持續創作,加速成長!這是我參與「掘金日新計劃 · 10 月更文挑戰」的第12天,點選檢視活動詳情

資料樣本處理的程式碼可能會變得雜亂且難以維護,因此理想狀態下我們應該將模型訓練的程式碼和資料集程式碼分開封裝,以獲得更好的程式碼可讀性和模組化程式碼。

PyTorch 提供了兩個基本方法 torch.utils.data.DataLoadertorch.utils.data.Dataset可以讓你預載入資料集或者你的資料。

Dataset儲存樣本及其相關的標籤, DataLoader封裝了關於 Dataset的迭代器,讓我們可以方便地讀取樣本。

PyTorch庫中也提供了一些常用的資料集可以方便使用者做預載入可以通過torch.utils.data.Dataset呼叫,還提供了一些對應資料集的方法。它們可以用於模型的原型和基準測試。

詳細可以戳這裡:


載入資料集

接下來我們看一下怎麼從TorchVision載入Fashion-MNIST資料集。

Fashion-MNIST是Zalando的一個數據集,包含6萬個訓練樣例和1萬個測試樣例。

每個樣例由兩部分組成,一個28×28灰度影象和一個十分類標籤中的某一個標籤。

我們要載入 FashionMNIST Dataset需要用到以下幾個引數: - root 資料集的儲存地址 - train 指定你要取訓練集還是測試集 - download=True 如果你指定的 root中沒有資料集,會自動從網上下載資料集 - transformtarget_transform 指定特徵和標籤轉換

下邊這段程式碼是取FashionMNIST的訓練集和測試集,root設定了一個data檔案,執行下邊這段程式碼以後你可以看到當前目錄下邊應該多了一個data資料夾,裡邊就是FashionMNIST資料集檔案了。

```py import torch from torch.utils.data import Dataset from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST( root="data", train=True, download=True, transform=ToTensor() )

test_data = datasets.FashionMNIST( root="data", train=False, download=True, transform=ToTensor() ) ```

迭代和視覺化資料集

我們可以像列表索引一樣檢視Datasets。 可以使用matplotlib視覺化我們的資料集。

其他程式碼解析看註釋。

至於畫子圖有兩個方法,二者的區別僅在於一個面向方法,一個面向物件,別的完全一樣。

  1. subplot ```py figure = plt.figure() cols, rows = 3, 3 for i in range(1, cols * rows + 1): plt.subplot(rows, cols, i)

    plt.show() ```

  2. add_subplot ```py figure = plt.figure() cols, rows = 3, 3 for i in range(1, cols * rows + 1): figure.subplot(rows, cols, i)

    plt.show() ```

py labels_map = { 0: "T-Shirt", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat", 5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle Boot", } figure = plt.figure(figsize=(8, 8)) cols, rows = 3, 3 for i in range(1, cols * rows + 1): sample_idx = torch.randint(len(training_data), size=(1,)).item() # 從資料集中隨機取樣 img, label = training_data[sample_idx] # 取得資料集的圖和標籤 figure.add_subplot(rows, cols, i) # 畫子圖,也可以plt.subplot(rows, cols, i) plt.title(labels_map[label]) plt.axis("off") plt.imshow(img.squeeze(), cmap="gray") # 是黑白圖,這裡做一個維度壓縮,把1通道的1壓縮掉 plt.show()

最後隨機取樣的結果大概是這樣的:

微信截圖_20220926200941.png


使用DataLoader

Dataset可以檢索我們資料集中一個樣本的特徵和標籤。但是在訓練模型的時候,我們通常希望資料以小批量(minibatch)的方式作為輸入,在每個epoch中重新調整資料以防止過擬合,並且還能使用Python的multiprocessing加速資料檢索。

DataLoader是一個迭代器,將剛才提到的複雜方法抽象成簡單的API。

```py from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True) test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True) ```

通過DataLoader迭代獲取資料

我們已經將資料集載入到DataLoader中,並可以根據需要迭代資料集。

下面的每次迭代返回一個批量資料的train_featurestrain_labels(分別包含batch_size=64個特徵和標籤)。

因為我們指定了shuffle=True,在遍歷所有批量之後,資料會被打亂(要對資料載入順序進行更細粒度的控制,戳這裡https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler 。

```py

Display image and label.

train_features, train_labels = next(iter(train_dataloader)) print(f"Feature batch shape: {train_features.size()}") print(f"Labels batch shape: {train_labels.size()}") img = train_features[0].squeeze() label = train_labels[0] plt.imshow(img, cmap="gray") plt.show() print(f"Label: {label}") ```


為你的資料建立自定義資料集

自定義Dataset類必須實現三個函式:__init____len____getitem__。看看這個FashionMNIST影象儲存在img_dir目錄中,它們的標籤單獨儲存在CSV檔案annotations_file中。 在下一節我們詳細分析一下每個函式中發生的事情。

```py import os import pandas as pd from torchvision.io import read_image

class CustomImageDataset(Dataset): def init(self, annotations_file, img_dir, transform=None, target_transform=None): self.img_labels = pd.read_csv(annotations_file) self.img_dir = img_dir self.transform = transform self.target_transform = target_transform

def __len__(self):
    return len(self.img_labels)

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label

```

init

__init__函式在例項化Dataset物件時執行一次,幫我們初始化一個目錄,其中包含影象、註釋檔案和兩個變換(下一節將詳細介紹)。

The labels.csv file looks like:

tshirt1.jpg, 0

tshirt2.jpg, 0

......

ankleboot999.jpg, 9

py def __init__(self, annotations_file, img_dir, transform=None, target_transform=None): self.img_labels = pd.read_csv(annotations_file) self.img_dir = img_dir self.transform = transform self.target_transform = target_transform

len

__len__方法返回我們資料集中的樣本數量。

py def __len__(self): return len(self.img_labels)

getitem

__getitem__函式當你給定一個索引idx的時候,用於載入並返回樣本。

基於索引,該函式去尋找影象在磁碟上的位置,使用read_image 將其轉換為一個張量,從self中的csv資料中檢索相應的標籤img_labels,呼叫它們上的變換函式(如果適用),並返回一個元組,元組中是影象的張量和對應的標籤。

py def __getitem__(self, idx): img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) image = read_image(img_path) label = self.img_labels.iloc[idx, 1] if self.transform: image = self.transform(image) if self.target_transform: label = self.target_transform(label) return image, label