PyTorch數據集處理
theme: nico highlight: agate
持續創作,加速成長!這是我參與「掘金日新計劃 · 10 月更文挑戰」的第12天,點擊查看活動詳情
數據樣本處理的代碼可能會變得雜亂且難以維護,因此理想狀態下我們應該將模型訓練的代碼和數據集代碼分開封裝,以獲得更好的代碼可讀性和模塊化代碼。
PyTorch 提供了兩個基本方法 torch.utils.data.DataLoader
和torch.utils.data.Dataset
可以讓你預加載數據集或者你的數據。
Dataset
存儲樣本及其相關的標籤, DataLoader
封裝了關於 Dataset
的迭代器,讓我們可以方便地讀取樣本。
PyTorch庫中也提供了一些常用的數據集可以方便用户做預加載可以通過torch.utils.data.Dataset
調用,還提供了一些對應數據集的方法。它們可以用於模型的原型和基準測試。
詳細可以戳這裏:
加載數據集
接下來我們看一下怎麼從TorchVision加載Fashion-MNIST數據集。
Fashion-MNIST是Zalando的一個數據集,包含6萬個訓練樣例和1萬個測試樣例。
每個樣例由兩部分組成,一個28×28灰度圖像和一個十分類標籤中的某一個標籤。
我們要加載 FashionMNIST Dataset需要用到以下幾個參數:
- root
數據集的存儲地址
- train
指定你要取訓練集還是測試集
- download=True
如果你指定的 root
中沒有數據集,會自動從網上下載數據集
- transform
、 target_transform
指定特徵和標籤轉換
下邊這段代碼是取FashionMNIST的訓練集和測試集,root設置了一個data文件,運行下邊這段代碼以後你可以看到當前目錄下邊應該多了一個data文件夾,裏邊就是FashionMNIST數據集文件了。
```py import torch from torch.utils.data import Dataset from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST( root="data", train=True, download=True, transform=ToTensor() )
test_data = datasets.FashionMNIST( root="data", train=False, download=True, transform=ToTensor() ) ```
迭代和可視化數據集
我們可以像列表索引一樣查看Datasets
。
可以使用matplotlib
可視化我們的數據集。
其他代碼解析看註釋。
至於畫子圖有兩個方法,二者的區別僅在於一個面向方法,一個面向對象,別的完全一樣。
-
subplot ```py figure = plt.figure() cols, rows = 3, 3 for i in range(1, cols * rows + 1): plt.subplot(rows, cols, i)
plt.show() ```
-
add_subplot ```py figure = plt.figure() cols, rows = 3, 3 for i in range(1, cols * rows + 1): figure.subplot(rows, cols, i)
plt.show() ```
py
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item() # 從數據集中隨機採樣
img, label = training_data[sample_idx] # 取得數據集的圖和標籤
figure.add_subplot(rows, cols, i) # 畫子圖,也可以plt.subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray") # 是黑白圖,這裏做一個維度壓縮,把1通道的1壓縮掉
plt.show()
最後隨機採樣的結果大概是這樣的:
使用DataLoader
Dataset
可以檢索我們數據集中一個樣本的特徵和標籤。但是在訓練模型的時候,我們通常希望數據以小批量(minibatch)的方式作為輸入,在每個epoch中重新調整數據以防止過擬合,並且還能使用Python的multiprocessing
加速數據檢索。
DataLoader
是一個迭代器,將剛才提到的複雜方法抽象成簡單的API。
```py from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True) test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True) ```
通過DataLoader迭代獲取數據
我們已經將數據集加載到DataLoader
中,並可以根據需要迭代數據集。
下面的每次迭代返回一個批量數據的train_features
和train_labels
(分別包含batch_size=64
個特徵和標籤)。
因為我們指定了shuffle=True
,在遍歷所有批量之後,數據會被打亂(要對數據加載順序進行更細粒度的控制,戳這裏https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler 。
```py
Display image and label.
train_features, train_labels = next(iter(train_dataloader)) print(f"Feature batch shape: {train_features.size()}") print(f"Labels batch shape: {train_labels.size()}") img = train_features[0].squeeze() label = train_labels[0] plt.imshow(img, cmap="gray") plt.show() print(f"Label: {label}") ```
為你的數據創建自定義數據集
自定義Dataset類必須實現三個函數:__init__
, __len__
和__getitem__
。看看這個FashionMNIST圖像存儲在img_dir目錄中,它們的標籤單獨存儲在CSV文件annotations_file中。
在下一節我們詳細分析一下每個函數中發生的事情。
```py import os import pandas as pd from torchvision.io import read_image
class CustomImageDataset(Dataset): def init(self, annotations_file, img_dir, transform=None, target_transform=None): self.img_labels = pd.read_csv(annotations_file) self.img_dir = img_dir self.transform = transform self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
```
init
__init__
函數在實例化Dataset對象時運行一次,幫我們初始化一個目錄,其中包含圖像、註釋文件和兩個變換(下一節將詳細介紹)。
The labels.csv file looks like:
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
py
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
len
__len__
方法返回我們數據集中的樣本數量。
py
def __len__(self):
return len(self.img_labels)
getitem
__getitem__
函數當你給定一個索引idx
的時候,用於加載並返回樣本。
基於索引,該函數去尋找圖像在磁盤上的位置,使用read_image
將其轉換為一個張量,從self
中的csv數據中檢索相應的標籤img_labels
,調用它們上的變換函數(如果適用),並返回一個元組,元組中是圖像的張量和對應的標籤。
py
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label