EasyCV帶你復現更好更快的自監督演算法-FastConvMAE
作者: 夕陌、謙言、莫申童、臨在
導讀
自監督學習(Self-Supervised Learning)利用大量無標註的資料進行表徵學習,在特定下游任務上對引數進行微調,極大降低了影象任務繁重的標註工作,節省大量人力成本。近年來,自監督學習在視覺領域大放異彩,受到了越來越多的關注。在CV領域湧現瞭如SIMCLR、MOCO、SwAV、DINO、MoBY、MAE等一系列工作。其中MAE的表現尤為驚豔,大家都被MAE簡潔高效的效能所吸引,紛紛在 MAE上進行改進,例如MixMIM,VideoMAE等工作。MAE詳解請參考往期文章:MAE自監督演算法介紹和基於EasyCV的復現 。
ConvMAE是由上海人工智慧實驗室和mmlab聯合發表在NeurIPS2022的一項工作,與MAE相比,訓練相同的epoch數, ImageNet-1K 資料集的finetune準確率提高了 1.4%,COCO2017資料集上微調 25 個 epoch相比微調100 個 epoch 的 MAE AP box提升2.9, AP mask提升2.2, 語義分割任務上相比MAE mIOU提升3.6%。在此基礎上,作者提出了FastConvMAE,進一步優化了訓練效能,僅預訓練50個epoch,ImageNet Finetuning的精度就超過MAE預訓練1600個epoch的精度0.77個點(83.6/84.37)。在檢測任務上,精度也超過ViTDet和Swin。
EasyCV是阿里巴巴開源的基於Pytorch,以自監督學習和Transformer技術為核心的 all-in-one 視覺演算法建模工具,覆蓋主流的視覺建模任務例如影象分類,度量學習,目標檢測,例項/語音/全景分割、關鍵點檢測等領域,具有較強的易用性和擴充套件性,同時注重效能調優,旨在為社群帶來更多更快更強的演算法。
近期FastConvMAE工作在EasyCV框架內首次對外開源,本文將重點介紹ConvMAE和FastConvMAE的主要工作,以及對應的程式碼實現,最後提供詳細的教程示例如何進行FastConvMAE的預訓練和下游任務的finetune。
ConvMAE
ConvMAE是由上海人工智慧實驗室和mmlab聯合發表在NeurIPS2022裡的一項工作,ConvMAE的提出證明了使用區域性歸納偏置和多尺度的金字塔結構,通過MAE的訓練方式可以學習到更好的特徵表示。該工作提出:
- 使用block-wise mask策略來確保計算效率。
- 輸出編碼器的多尺度特徵,同時捕獲細粒度和粗粒度影象資訊。
原文參考:http://arxiv.org/abs/2205.03892
實驗結果顯示,上述兩項策略是簡潔而有效的,使得ConvMAE在多個視覺任務中相比MAE獲得了明顯提升。以ConvMAE-Base和MAE-Base相比為例:在影象分類任務上, ImageNet-1K 資料集的微調準確率提高了 1.4%;在目標檢測任務上,COCO2017微調 25 個 epoch 的AP box達到53.2%,AP mask達到47.1%,與微調100 個 epoch 的 MAE-Base相比分別提升2.9% 和 2.2% ;在語義分割任務上,使用UperNet網路頭,ConvMAE-Base在ADE20K上的mIoU達到51.7%,相比MAE-Base提升3.6%。
ConvMAE的總體流程
與MAE不同的是,ConvMAE的編碼器將輸入影象逐步抽象為多尺度token embedding,而解碼器則重建被mask掉的tokens對應的畫素。對於前面stage部分的高解析度token embedding,採用卷積塊對區域性進行編碼,對於後面的低解析度token embedding,則使用transformer來聚合全域性資訊。因此,ConvMAE的編碼器在不同階段可以同時獲得區域性和全域性資訊,並生成多尺度特徵。
當前的masked auto encoding框架,如BEiT,SimMIM,所採用的mask策略不能直接用於ConvMAE,因為在後面的transformer階段,所有的tokens都需要保留。這導致對大模型進行預訓練的計算成本過高,失去了MAE在transformer編碼器中省去masked tokens的效率優勢。此外,直接使用convolution-transformer結構的編碼器進行預訓練會導致卷積部分因為隨機的mask而造成預訓練的資訊洩露,因而也會降低預訓練所得模型的質量。
針對這些問題,ConvMAE提出了混合convolution-transformer架構。ConvMAE採用分塊mask策略 (block-wise masking strategy):,首先隨機在後期的獲取transformer token中生成後期的mask,然後對mask固定位置逐步進行上取樣到早期卷積階段的高解析度。這樣,後期處理的token可以完全分離為masked tokens和visible tokens,從而並繼承了MAE使用稀疏encoder的計算效率。
下面將分別針對encoder、mask策略以及decoder部分展開介紹。
Encoder
如總體流程圖所示,encoder包括3 個階段,每個階段輸出的特徵維度分別是:H/4 × W/4, H/8 × W/8, H/16 × W/16,其中H × W為輸入影象解析度。前兩個是卷積階段,使用卷積模組將輸入轉換為token embeddings E1 ∈ R^(H/4 × W/4 ×C1) and E2 ∈ R^(H/8 × W/8 ×C2) 。其中卷積模組用5 × 5的卷積代替self-attention操作。前兩個階段的感受野較小主要捕捉影象的區域性特徵,第三個階段使用transformer模組,將粗粒度特徵融合, 並將感受野擴充套件到整個影象,獲得token embeddings E3 ∈ R(H/16 × W/16 ×C3)。在每個階段之間,使用stride為2的卷積對tokens進行下采樣。
其他包含transformer的結構,如CPT、Container、Uniformer、CMT、Swin等,在第一階段的輸入用相對位置編碼或零填充卷積替代絕對位置編碼,而作者發現在第3個transformer stage中使用絕對位置編碼可獲得最優效能。class token也從編碼器中移除。
Mask策略
MAE、BEiT等,對輸入patch採用隨機mask。但同樣的策略不能直接應用於ConvMAE編碼器:如果獨立地從stage-1的H/4 × W/4個tokens中隨機抽取mask,將導致降取樣後的stage-3的幾乎所有token都有部分可見資訊,使得編碼器不再稀疏。因此作者提出,從stage-3的輸入tokens中以同樣比例 (例如75%)生成mask,再對mask上取樣2倍和4倍,分別作為stage-2和stage-1的mask。這樣,ConvMAE在3個階段都只含有很少的(例如25%)可見token,從而使得預訓練時編碼器的效率不受影響。而解碼器的任務e則保持相同,即重建編碼過程中被mask掉的tokens。
同時,前2個階段的5X5卷積操作會在masked patches的邊緣處洩漏不可見token的重建答案。為了避免這種情況保證預訓練的質量,作者在前兩個階段採用了masked convolution, 使被mask掉的區域不參與編碼過程。
Decoder
原始MAE的decoder的輸入以編碼器的輸出和mask掉的tokens作為輸入,然後通過堆疊的transformer blocks進行影象重建。ConvMAE編碼器獲得多尺度特徵E1、E2、E3,同時捕獲細粒度和粗粒度影象資訊。為了更好地的預訓練,作者通過stride-4和stride-2卷積將E1和E2下采樣到E3的相同大小,並進行多尺度特徵融合,再通過一個linear層得到最終要輸入給 decoder 的可見token。目標函式和MAE相同,僅採用MSE作為損失函式,計算預測向量和被mask掉畫素值之前的MSE loss,即只考慮mask掉的patches的重建。
下游任務
預訓練之後,ConvMAE可以輸出多尺度的特徵用於檢測分割任務。
檢測任務中,先將第stage-3的輸出特徵E3通過2x2最大池化獲得E4。由於ConvMAE stage-3有11個self-attention層(ConvMAE-base),計算成本過高,作者參考ViT的benchmark將stage-3中除第1、4、7、11之外的所有global self-attention layers替換為了Window size7×7 的 local self-attention 層。修改後的local self-attention仍然由預訓練的global self-attention進行初始化。global transformer blocks之間共享global relative position bias,local transformer blocks之間共享local relative position bias,這樣就大大減輕了stage-3的計算和GPU記憶體開銷。然後將多尺度特徵E1、E2、E3、E4送入MaskRCNN head進行目標檢測。
而分割任務保留了stage-3的結構。
Benchmark
影象分類
ConvMAE基於ImageNet-1K,mask掉25%的input token做預訓練,Decoder部分是一個8層的transformer,embedding 維度是512,head是12個。預訓練引數和分類finetuning結果如下:
BEiT預訓練300個epoch,finetune的精度達到83.0%,linear-prob的精度是37.6%。與BEiT相比,ConVMAE僅需要25%的token和一個輕量級的decoder finetune可達到85%,linear-prob可以達到70.9%。與原來的MAE相比,預訓練相同的1600個epoch,ConVMAE比MAE提升1.4個點。與SimMIM(backbone使用Swin-B)相比提升了1個點。
檢測
作者用ConvMAE替換Mask-RCNN的backbone,載入ConvMAE的預訓練模型訓練COCO資料集。
與ViT在COCO資料集上finetune100個epoch的結果相比,ConVMAE僅finetune 25個epoch在APbox和APmask就提升了2.9和2.2個點。
與ViTDet和MIMDet相比,ConvMAE finetune epoch更少、引數更少,分別超過了它們2.0%和1.7%。
與Swin和MViTv2相比,在APbox/APmask,其效能分別高出4.0%/3.6%和2.2%/1.4%。
分割
作者用ConvMAE替換UperNet的backbone,載入ConvMAE的預訓練模型訓練ADE20K資料集。
從結果中可以看出,相比與DeiT, Swin,MoCo-v3等網路ConvMAE取得了更高的效能(51.7%)。表明ConvMAE的多尺度特徵大大縮小了預訓練Backbone 和下游網路之間的傳輸差距。
Fast ConvMAE
ConvMAE雖然在分類、檢測、分割等下游任務中有了精度提升,並解決了pretraining-finetuning 的差異問題,但是模型的預訓練依然耗時,ConvMAE的結果中,模型預訓練了1600個epoch,因此作者又在ConvMAE的基礎之上做了進一步的效能優化,提出了Fast ConvMAE,FastConvMAE提出了mask互補和deocder融合的方案,來實現快速的mask建模方案,進一步縮短了預訓練的時間,從原來預訓練的1600epoch縮短到了50epoch。FastConvMAE的正式論文作者會在未來發出。
首先,FastConvMAE創新地設計出decoder互相融合的Mixture of Reconstructor (MoR),可以讓masked patches從不同的tokenizer中學習到互補的資訊,包括EMA 的self-ensembling性質,DINO的similarity-discrimination能力,以及CLIP的multimodal知識。MoR主要包括兩個部分,Partially-Shared Decoder(PS-Decoder)和Mixture of Tokenizer(MoT), PS-Decoder可以避免不同tokenizer的不同知識之間會產生梯度的衝突,MoT是用來生成不同的token作為masked patches的target。
同時Mask部分採用了互補策略,原來的mask每次只會保留例如25%的tokens,FastConvMAE將mask分成了4份,每一份都保留25%,4份mask之間互補。這樣,相當於1張圖片被分成了4張圖片進行學習,理論上達到了4倍的學習效果。
def random_masking(self, x, mask_ratio=None): """ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise. x: [N, L, D], sequence """ N = x.shape[0] L = self.num_patches len_keep = int(L * (1 - mask_ratio)) noise = torch.rand(N, L, device=x.device) # noise in [0, 1] # sort noise for each sample ids_shuffle = torch.argsort( noise, dim=1) # ascend: small is keep, large is remove ids_restore = torch.argsort(ids_shuffle, dim=1) # keep the first subset ids_keep1 = ids_shuffle[:, :len_keep] ids_keep2 = ids_shuffle[:, len_keep:2 * len_keep] ids_keep3 = ids_shuffle[:, 2 * len_keep:3 * len_keep] ids_keep4 = ids_shuffle[:, 3 * len_keep:] # generate the binary mask: 0 is keep, 1 is remove mask1 = torch.ones([N, L], device=x.device) mask1[:, :len_keep] = 0 # unshuffle to get the binary mask mask1 = torch.gather(mask1, dim=1, index=ids_restore) mask2 = torch.ones([N, L], device=x.device) mask2[:, len_keep:2 * len_keep] = 0 # unshuffle to get the binary mask mask2 = torch.gather(mask2, dim=1, index=ids_restore) mask3 = torch.ones([N, L], device=x.device) mask3[:, 2 * len_keep:3 * len_keep] = 0 # unshuffle to get the binary mask mask3 = torch.gather(mask3, dim=1, index=ids_restore) mask4 = torch.ones([N, L], device=x.device) mask4[:, 3 * len_keep:4 * len_keep] = 0 # unshuffle to get the binary mask mask4 = torch.gather(mask4, dim=1, index=ids_restore) return [ids_keep1, ids_keep2, ids_keep3, ids_keep4], [mask1, mask2, mask3, mask4], ids_restore
前兩個卷積階段將輸入轉換為embeddings tokens E1和E2。然後E1和E2分別從4份mask中獲取4份可見的tokens並進行拼接,作為decoder的輸入,Decoder處理的是拼接後的tokens。程式碼參考如下:
def encoder_forward(self, x, mask_ratio): # embed patches ids_keep, masks, ids_restore = self.random_masking(x, mask_ratio) mask_for_patch1 = [ 1 - mask.reshape(-1, 14, 14).unsqueeze(-1).repeat( 1, 1, 1, 16).reshape(-1, 14, 14, 4, 4).permute( 0, 1, 3, 2, 4).reshape(x.shape[0], 56, 56).unsqueeze(1) for mask in masks ] mask_for_patch2 = [ 1 - mask.reshape(-1, 14, 14).unsqueeze(-1).repeat( 1, 1, 1, 4).reshape(-1, 14, 14, 2, 2).permute( 0, 1, 3, 2, 4).reshape(x.shape[0], 28, 28).unsqueeze(1) for mask in masks ] s1 = self.patch_embed1(x) s1 = self.pos_drop(s1) for blk in self.blocks1: s1 = blk(s1, mask_for_patch1) s2 = self.patch_embed2(s1) for blk in self.blocks2: s2 = blk(s2, mask_for_patch2) stage1_embed = self.stage1_output_decode(s1).flatten(2).permute(0, 2, 1) stage2_embed = self.stage2_output_decode(s2).flatten(2).permute(0, 2, 1) stage1_embed_1 = torch.gather( stage1_embed, dim=1, index=ids_keep[0].unsqueeze(-1).repeat(1, 1, stage1_embed.shape[-1])) stage2_embed_1 = torch.gather( stage2_embed, dim=1, index=ids_keep[0].unsqueeze(-1).repeat(1, 1, stage2_embed.shape[-1])) stage1_embed_2 = torch.gather( stage1_embed, dim=1, index=ids_keep[1].unsqueeze(-1).repeat(1, 1, stage1_embed.shape[-1])) stage2_embed_2 = torch.gather( stage2_embed, dim=1, index=ids_keep[1].unsqueeze(-1).repeat(1, 1, stage2_embed.shape[-1])) stage1_embed_3 = torch.gather( stage1_embed, dim=1, index=ids_keep[2].unsqueeze(-1).repeat(1, 1, stage1_embed.shape[-1])) stage2_embed_3 = torch.gather( stage2_embed, dim=1, index=ids_keep[2].unsqueeze(-1).repeat(1, 1, stage2_embed.shape[-1])) stage1_embed_4 = torch.gather( stage1_embed, dim=1, index=ids_keep[3].unsqueeze(-1).repeat(1, 1, stage1_embed.shape[-1])) stage2_embed_4 = torch.gather( stage2_embed, dim=1, index=ids_keep[3].unsqueeze(-1).repeat(1, 1, stage2_embed.shape[-1])) stage1_embed = torch.cat([ stage1_embed_1, stage1_embed_2, stage1_embed_3, stage1_embed_4 ]) stage2_embed = torch.cat([ stage2_embed_1, stage2_embed_2, stage2_embed_3, stage2_embed_4 ]) x = self.patch_embed3(s2) x = x.flatten(2).permute(0, 2, 1) x = self.patch_embed4(x) # add pos embed w/o cls token x = x + self.pos_embed x1 = torch.gather(x, dim=1, index=ids_keep[0].unsqueeze(-1).repeat(1, 1, x.shape[-1])) x2 = torch.gather(x, dim=1, index=ids_keep[1].unsqueeze(-1).repeat(1, 1, x.shape[-1])) x3 = torch.gather(x, dim=1, index=ids_keep[2].unsqueeze(-1).repeat(1, 1, x.shape[-1])) x4 = torch.gather(x, dim=1, index=ids_keep[3].unsqueeze(-1).repeat(1, 1, x.shape[-1])) x = torch.cat([x1, x2, x3, x4]) # apply Transformer blocks for blk in self.blocks3: x = blk(x) x = x + stage1_embed + stage2_embed x = self.norm(x) mask = torch.cat([masks[0], masks[1], masks[2], masks[3]]) return x, mask, ids_restore
Benchmark
EasyCV復現的結果如下:
ImageNet Pretrained
Config |
Epochs |
Download |
50 |
ImageNet Finetuning
Algorithm |
Fintune Config |
Pretrained Config |
Top-1 |
Download |
Fast ConvMAE(EasyCV) |
84.4% |
- log |
||
Fast ConvMAE(官方) |
|
84.4% |
Object Detection
Algorithm |
Eval Config |
Pretrained Config |
mAP (Box) |
mAP (Mask) |
Download |
Fast ConvMAE(EasyCV) |
51.3% |
45.6% |
|||
Fast ConvMAE(官方) |
51.0% |
45.4% |
從結果可以看出,僅預訓練50個epoch,ImageNet Finetuning的精度就超過MAE預訓練1600個epoch的精度0.77個點(83.6/84.37)。在檢測任務上,精度也超過ViTDet和Swin。
FastConvMAE的更多官方結果請參考:http://github.com/Alpha-VL/FastConvMAE 。
Tutorial
一、安裝依賴包
如果是在本地開發環境執行,可以參考該連結安裝環境。若使用PAI-DSW進行實驗則無需安裝相關依賴,在PAI-DSW docker中已內建相關環境。
二、資料準備
資料準備請參考文件:http://github.com/alibaba/EasyCV/blob/master/docs/source/prepare_data.md
三、模型預訓練
FastConvMAE佔用視訊記憶體較大,建議使用A100資源。(FastConvMAE一次forward-backward等價於ConvMAE forward-backward 4次)
在EasyCV中,使用配置檔案的形式來實現對模型引數、資料輸入及增廣方式、訓練策略的配置,僅通過修改配置檔案中的引數設定,就可以完成實驗配置進行訓練。
配置EasyCV路徑
# 檢視easycv安裝位置 import easycv print(easycv.__file__)
$ export PYTHONPATH=$PYTHONPATH:${your EasyCV root path}
訓練
$ python -m torch.distributed.launch --nproc_per_node=8 --master_port=29930 \ tools/train.py \ configs/selfsup/fast_convmae/fast_convmae_vit_base_patch16_8xb64_50e.py \ --work_dir ./work_dir \ --launcher pytorch
下游任務finetune
下載預訓練模型
$ wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/FastConvMAE/pretrained/epoch_50.pth
- 單卡
$ python tools/train.py \ ${CONFIG_FILE} \ --work_dir ./work_dir \ --load_from=./epoch_50.pth
- 多卡
$ python -m torch.distributed.launch --nproc_per_node=8 --master_port=29930 \ tools/train.py \ ${CONFIG_FILE} \ --work_dir ./work_dir \ --launcher pytorch \ --load_from=./epoch_50.pth
分類任務 CONFIG_FILE 請參考:http://github.com/alibaba/EasyCV/tree/master/benchmarks/selfsup/classification/imagenet/fast_convmae_vit_base_patch16_8xb64_100e_fintune.py
分類任務 CONFIG_FILE 請參考:http://github.com/alibaba/EasyCV/blob/master/benchmarks/selfsup/detection/coco/mask_rcnn_conv_vitdet_50e_coco.py
Reference
EasyCV:http://github.com/alibaba/EasyCV/blob/master/easycv/models/backbones/conv_mae_vit.py
EasyCV往期分享
- DeepRec 大規模稀疏模型訓練推理引擎
- 跨模態學習能力再升級,EasyNLP電商文圖檢索效果重新整理SOTA
- DeepRec 大規模稀疏模型訓練推理引擎
- 基於EMR的新一代資料湖儲存加速技術詳解
- EasyNLP帶你實現中英文機器閱讀理解
- 最高增強至1440p,阿里雲釋出端側實時超分工具,低成本實現高畫質
- EasyNLP帶你實現中英文機器閱讀理解
- 跨模態學習能力再升級,EasyNLP電商文圖檢索效果重新整理SOTA
- EasyCV帶你復現更好更快的自監督演算法-FastConvMAE
- SREWorks前端低程式碼工程設計概覽
- 阿里雲大資料助力知衣科技打造AI服裝行業核心競爭力
- EasyNLP玩轉文字摘要(新聞標題)生成
- 資料湖管理及優化
- EMR重磅釋出智慧運維診斷系統(EMR Doctor)——開源大資料平臺運維利器
- 資料湖統一元資料與許可權
- 超長序列,超快預測!深勢科技聯手阿里雲,AI 蛋白質預測再下一城
- 中文稀疏GPT大模型落地 — 通往低成本&高效能多工通用自然語言理解的關鍵里程碑
- YOLOX-PAI: 加速 YOLOX, 比 YOLOV6 更快更強
- Kubernetes資源編排系列之五: OAM篇
- 動態尺寸模型優化實踐之Shape Constraint IR Part II