使用PyTorch來進展不平衡資料集的影象分類

語言: CN / TW / HK

導讀

一個非常簡單和容易上手的例子。

圖片

對於教程中使用的大多數人工資料集,每個類都有相同數量的資料。然而,在實際應用中,這種情況很少發生。今天,我將給你介紹來自Kaggle的木薯葉分類,並告訴你當類頻率有很大差異時該怎麼做。

處理類別的不平衡

有兩種方法可以解決這個問題。

  • WeightedRandomSampler
  • loss函式中的weight引數

下一步是建立一個有5個方法的CassavaClassifier類:load_data()、load_model()、fit_one_epoch()、val_one_epoch()和fit()。

在load_data()中,將構造一個train和驗證資料集,並返回資料載入器以供進一步使用。

在load_model()中定義了體系結構、損失函式和優化器。

fit方法包含一些初始化和對fit_one_epoch()和val_one_epoch()的迴圈。

早期停止

早期停止類有助於根據驗證損失跟蹤最佳模型,並儲存檢查點。

```

Callbacks

Early stopping

class EarlyStopping:   def init(self, patience=1, delta=0, path='checkpoint.pt'):     self.patience = patience     self.delta = delta     self.path= path     self.counter = 0     self.best_score = None     self.early_stop = False

def call(self, val_loss, model):     if self.best_score is None:       self.best_score = val_loss       self.save_checkpoint(model)     elif val_loss > self.best_score:       self.counter +=1       if self.counter >= self.patience:         self.early_stop = True      else:       self.best_score = val_loss       self.save_checkpoint(model)       self.counter = 0      

def save_checkpoint(self, model):     torch.save(model.state_dict(), self.path) ```

Init

我們首先初始化CassavaClassifier類。

class CassavaClassifier():     def __init__(self, data_dir, num_classes, device, Transform=None, sample=False, loss_weights=False, batch_size=16,      lr=1e-4, stop_early=True, freeze_backbone=True):     #############################################################################################################     # data_dir - directory with images in subfolders, subfolders name are categories     # Transform - data augmentations     # sample - if the dataset is imbalanced set to true and RandomWeightedSampler will be used     # loss_weights - if the dataset is imbalanced set to true and weight parameter will be passed to loss function     # freeze_backbone - if using pretrained architecture freeze all but the classification layer     ###############################################################################################################         self.data_dir = data_dir         self.num_classes = num_classes         self.device = device         self.sample = sample         self.loss_weights = loss_weights         self.batch_size = batch_size         self.lr = lr         self.stop_early = stop_early         self.freeze_backbone = freeze_backbone         self.Transform = Transform

Load Data

訓練影象被組織在子資料夾中,子資料夾名稱表示影象的類。這是影象分類問題的典型情況,幸運的是,不需要編寫自定義資料集類。在這種情況下,可以立即使用torchvision中的ImageFolder。如果你想使用WeightedRandomSampler,你需要為資料集的每個元素指定一個權重。通常,總影象總比上類別數被用作一個權重。

``` def load_data(self):     train_full = torchvision.datasets.ImageFolder(self.data_dir, transform=self.Transform)     train_set, val_set = random_split(train_full, [math.floor(len(train_full)0.8), math.ceil(len(train_full)0.2)])

self.train_classes = [label for _, label in train_set]     if self.sample:         # Need to get weight for every image in the dataset         class_count = Counter(self.train_classes)         class_weights = torch.Tensor([len(self.train_classes)/c for c in pd.Series(class_count).sort_index().values])          # Can't iterate over class_count because dictionary is unordered

sample_weights = [0] * len(train_set)         for idx, (image, label) in enumerate(train_set):             class_weight = class_weights[label]             sample_weights[idx] = class_weight

sampler = WeightedRandomSampler(weights=sample_weights,                                         num_samples = len(train_set), replacement=True)           train_loader = DataLoader(train_set, batch_size=self.batch_size, sampler=sampler)     else:         train_loader = DataLoader(train_set, batch_size=self.batch_size, shuffle=True)

val_loader = DataLoader(val_set, batch_size=self.batch_size)

return train_loader, val_loader ```

Load Model

在該方法中,我使用遷移學習,架構引數從預先訓練的resnet50和efficientnet-b7中選擇。CrossEntropyLoss和許多其他損失函式都有權重引數。這是一個手動調整引數,用於處理不平衡。在這種情況下,不需要為每個引數定義權重,只需為每個類定義權重。

``` def load_model(self, arch='resnet'):     ##############################################################################################################     # arch - choose the pretrained architecture from resnet or efficientnetb7     ##############################################################################################################      if arch == 'resnet':         self.model = torchvision.models.resnet50(pretrained=True)         if self.freeze_backbone:             for param in self.model.parameters():                 param.requires_grad = False         self.model.fc = nn.Linear(in_features=self.model.fc.in_features, out_features=self.num_classes)     elif arch == 'efficient-net':         self.model = EfficientNet.from_pretrained('efficientnet-b7')         if self.freeze_backbone:             for param in self.model.parameters():                 param.requires_grad = False         self.model._fc = nn.Linear(in_features=self.model._fc.in_features, out_features=self.num_classes)    

self.model = self.model.to(self.device)

self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr) 

if self.loss_weights:         class_count = Counter(self.train_classes)         class_weights = torch.Tensor([len(self.train_classes)/c for c in pd.Series(class_count).sort_index().values])         # Cant iterate over class_count because dictionary is unordered         class_weights = class_weights.to(self.device)           self.criterion = nn.CrossEntropyLoss(class_weights)     else:         self.criterion = nn.CrossEntropyLoss()  ```

Fit One Epoch

這個方法只包含一個經典的訓練迴圈,帶有訓練損失記錄和tqdm進度條。

``` def fit_one_epoch(self, train_loader, epoch, num_epochs ):      step_train = 0

train_losses = list() # Every epoch check average loss per batch      train_acc = list()     self.model.train()     for i, (images, targets) in enumerate(tqdm(train_loader)):         images = images.to(self.device)         targets = targets.to(self.device)

logits = self.model(images)         loss = self.criterion(logits, targets)

loss.backward()         self.optimizer.step()

self.optimizer.zero_grad()

train_losses.append(loss.item())

#Calculate running train accuracy         predictions = torch.argmax(logits, dim=1)         num_correct = sum(predictions.eq(targets))         running_train_acc = float(num_correct) / float(images.shape[0])         train_acc.append(running_train_acc)              train_loss = torch.tensor(train_losses).mean()         print(f'Epoch {epoch}/{num_epochs-1}')       print(f'Training loss: {train_loss:.2f}') ```

Validate one epoch

與上面類似,但此方法在驗證資料載入器上迭代。在每一個epoch'之後,平均batch損失和準確性被打印出來。

``` def val_one_epoch(self, val_loader, scaler):         val_losses = list()         val_accs = list()         self.model.eval()         step_val = 0         with torch.no_grad():             for (images, targets) in val_loader:                 images = images.to(self.device)                 targets = targets.to(self.device)

logits = self.model(images)                 loss = self.criterion(logits, targets)                 val_losses.append(loss.item())                                    predictions = torch.argmax(logits, dim=1)                 num_correct = sum(predictions.eq(targets))                 running_val_acc = float(num_correct) / float(images.shape[0])

val_accs.append(running_val_acc)             

self.val_loss = torch.tensor(val_losses).mean()             val_acc = torch.tensor(val_accs).mean() # Average acc per batch                      print(f'Validation loss: {self.val_loss:.2f}')               print(f'Validation accuracy: {val_acc:.2f}')  ```

Fit

Fit方法在訓練和驗證過程中經歷了許多階段和迴圈。如果預訓練模型的引數在開始時被凍結,那麼unfreeze_after定義了整個模型在多少個epoch之後開始訓練。在此之前,只訓練全連線層(分類器)。

def fit(self, train_loader, val_loader, num_epochs=10, unfreeze_after=5, checkpoint_dir='checkpoint.pt'):     if self.stop_early:         early_stopping = EarlyStopping(         patience=5,          path=checkpoint_dir)        for epoch in range(num_epochs):         if self.freeze_backbone:             if epoch == unfreeze_after:  # Unfreeze after x epochs                 for param in self.model.parameters():                     param.requires_grad = True         self.fit_one_epoch(train_loader, scaler, epoch, num_epochs)         self.val_one_epoch(val_loader, scaler)         if self.stop_early:             early_stopping(self.val_loss, self.model)             if early_stopping.early_stop:                 print('Early Stopping')                 print(f'Best validation loss: {early_stopping.best_score}')                 break

Run

現在,可以初始化CassavaClassifier類、建立dataloaders、設定模型並執行整個過程了。

``` Transform = T.Compose(                     [T.ToTensor(),                     T.Resize((256, 256)),                     T.RandomRotation(90),                     T.RandomHorizontalFlip(p=0.5),                     T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') data_dir = "Data/cassava-disease/train/train"

classifier = CassavaClassifier(data_dir=data_dir, num_classes=5, device=device, sample=True, Transform=Transform) train_loader, val_loader = classifier.load_data() classifier.load_model() classifier.fit(num_epochs=20, unfreeze_after=5, train_loader=train_loader, val_loader=val_loader) ```

Inference

使用ImageFolder載入測試資料是不可能的,因為顯然沒有帶有類的子資料夾。因此,我建立了一個返回影象和影象id的自定義資料集。隨後,載入模型檢查點,通過推理迴圈執行它,並將預測儲存到資料幀中。將資料幀匯出為CSV並提交結果。

```

Inference

model = torchvision.models.resnet50()

model = EfficientNet.from_name('efficientnet-b7')

model.fc = nn.Linear(in_features=model.fc.in_features, out_features=5) model = model.to(device) checkpoint = torch.load('Data/cassava-disease/sampler_checkpoint.pt') model.load_state_dict(checkpoint) model.eval()

Dataset for test data

class Cassava_Test(Dataset):   def init(self, dir, transform=None):     self.dir = dir     self.transform = transform

self.images = os.listdir(self.dir)  

def len(self):     return len(self.images)

def getitem(self, idx):     img = Image.open(os.path.join(self.dir, self.images[idx]))     return self.transform(img), self.images[idx] 

test_dir = 'Data/cassava-disease/test/test/0' test_set = Cassava_Test(test_dir, transform=Transform) test_loader = DataLoader(test_set, batch_size=4)  

Test loop

sub = pd.DataFrame(columns=['category', 'id']) id_list = [] pred_list = []

model = model.to(device)

with torch.no_grad():   for (image, image_id) in test_loader:     image = image.to(device)

logits = model(image)     predicted = list(torch.argmax(logits, 1).cpu().numpy())

for id in image_id:       id_list.append(id)        for prediction in predicted:       pred_list.append(prediction) sub['category'] = pred_list sub['id'] = id_list

mapping = {0:'cbb', 1:'cbsd', 2:'cgm', 3:'cmd', 4:'healthy'}

sub['category'] = sub['category'].map(mapping) sub = sub.sort_values(by='id')

sub.to_csv('Cassava_sub.csv', index=False) ```

如果在方案中包含WeightedRandomSampler或損失權值,則測試集的精度會提高2%。對於僅僅幾行程式碼來說,這是一個很好的改進。對於這個資料集,我沒有看到這兩種方法在精度上的巨大差異,但WeightedRandomSampler的表現要好一些。

不同的學習速度、優化器和資料擴充套件肯定有自己的發展空間。然而,對於這種簡單的方法來說,86%的準確率似乎足夠好了。

英文原文:https://marekpaulik.medium.com/imbalanced-dataset-image-classification-with-pytorch-6de864982eb1