0
本文作者: AI研習(xí)社-譯站 | 2020-11-22 09:31 |
譯者:AI研習(xí)社(季一帆)
雙語(yǔ)原文鏈接:Easy Self-Supervised Learning with BYOL
注:本文所有代碼可見(jiàn)Google Colab notebook,你可用Colab的免費(fèi)GPU運(yùn)行或改進(jìn)。
在深度學(xué)習(xí)中,經(jīng)常遇到的問(wèn)題是沒(méi)有足夠的標(biāo)記數(shù)據(jù),而手工標(biāo)記數(shù)據(jù)耗費(fèi)大量時(shí)間且人工成本高昂?;诖?,自我監(jiān)督學(xué)習(xí)成為深度學(xué)習(xí)的研究熱點(diǎn),旨在從未標(biāo)記樣本中進(jìn)行學(xué)習(xí),以緩解數(shù)據(jù)標(biāo)注困難的問(wèn)題。子監(jiān)督學(xué)習(xí)的目標(biāo)很簡(jiǎn)單,即訓(xùn)練一個(gè)模型使得相似的樣本具有相似的表示,然而具體實(shí)現(xiàn)卻困難重重。經(jīng)過(guò)谷歌這樣的諸多先驅(qū)者若干年的研究,子監(jiān)督學(xué)習(xí)如今已取得一系列的進(jìn)步與發(fā)展。
在BYOL之前,多數(shù)自我監(jiān)督學(xué)習(xí)都可分為對(duì)比學(xué)習(xí)或生成學(xué)習(xí),其中,生成學(xué)習(xí)一般GAN建模完整的數(shù)據(jù)分布,計(jì)算成本較高,相比之下,對(duì)比學(xué)習(xí)方法就很少面臨這樣的問(wèn)題。對(duì)此,BYOL的作者這樣說(shuō)道:
通過(guò)對(duì)比方法,同一圖像不同視圖的表示更接近(正例),不同圖像視圖的表示相距較遠(yuǎn)(負(fù)例),通過(guò)這樣的方式減少表示的生成成本。
為了實(shí)現(xiàn)對(duì)比方法,我們必須將每個(gè)樣本與其他許多負(fù)例樣本進(jìn)行比較。然而這樣會(huì)使訓(xùn)練很不穩(wěn)定,同時(shí)會(huì)增大數(shù)據(jù)集的系統(tǒng)偏差。BYOL的作者顯然明白這點(diǎn):
對(duì)比方法對(duì)圖像增強(qiáng)的方式非常敏感。例如,當(dāng)消除圖像增強(qiáng)中的顏色失真時(shí),SimCLR表現(xiàn)不佳??赡艿脑蚴?,同一圖像的不同裁切一般會(huì)共享顏色直方圖,而不同圖像的顏色直方圖是不同的。因此,在對(duì)比任務(wù)中,可以通過(guò)關(guān)注顏色直方圖,使用隨機(jī)裁切方式實(shí)現(xiàn)圖像增強(qiáng),其結(jié)果表示幾乎無(wú)法保留顏色直方圖之外的信息。
不僅僅是顏色失真,其他類型的數(shù)據(jù)轉(zhuǎn)換也是如此。一般來(lái)說(shuō),對(duì)比訓(xùn)練對(duì)數(shù)據(jù)的系統(tǒng)偏差較為敏感。在機(jī)器學(xué)習(xí)中,數(shù)據(jù)偏差是一個(gè)廣泛存在的問(wèn)題(見(jiàn)facial recognition for women and minorities),這對(duì)對(duì)比方法來(lái)說(shuō)影響更大。不過(guò)好在BYOL不依賴負(fù)采樣,從而很好的避免了該問(wèn)題。
BYOL的目標(biāo)與對(duì)比學(xué)習(xí)相似,但一個(gè)很大的區(qū)別是,BYOL不關(guān)心不同樣本是否具有不同的表征(即對(duì)比學(xué)習(xí)中的對(duì)比部分),僅僅使相似的樣品表征類似??瓷先ニ坪鯚o(wú)關(guān)緊要,但這樣的設(shè)定會(huì)顯著改善模型訓(xùn)練效率和泛化能力:
由于不需要負(fù)采樣,BLOY有更高的訓(xùn)練效率。在訓(xùn)練中,每次遍歷只需對(duì)每個(gè)樣本采樣一次,而無(wú)需關(guān)注負(fù)樣本。
BLOY模型對(duì)訓(xùn)練數(shù)據(jù)的系統(tǒng)偏差不敏感,這意味著模型可以對(duì)未見(jiàn)樣本也有較好的適用性。
BYOL最小化樣本表征和該樣本變換之后的表征間的距離。其中,不同變換類型包括0:平移、旋轉(zhuǎn)、模糊、顏色反轉(zhuǎn)、顏色抖動(dòng)、高斯噪聲等(我在此以圖像操作來(lái)舉例說(shuō)明,但BYOL也可以處理其他數(shù)據(jù)類型)。至于是單一變換還是幾種不同類型的聯(lián)合變換,這取決于你自己,不過(guò)我一般會(huì)采用聯(lián)合變換。但有一點(diǎn)需要注意,如果你希望訓(xùn)練的模型能夠應(yīng)對(duì)某種變換,那么用該變換處理訓(xùn)練數(shù)據(jù)時(shí)必要的。
手把手教你編碼BYOL
首先是數(shù)據(jù)轉(zhuǎn)換增強(qiáng)的編碼。BYOL的作者定義了一組類似于SimCLR的特殊轉(zhuǎn)換:
import random from typing import Callable, Tuple from kornia import augmentation as aug from kornia import filters from kornia.geometry import transform as tf import torch from torch import nn, Tensor class RandomApply(nn.Module): def __init__(self, fn: Callable, p: float): super().__init__() self.fn = fn self.p = p def forward(self, x: Tensor) -> Tensor: return x if random.random() > self.p else self.fn(x) def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module: return nn.Sequential( tf.Resize(size=image_size), RandomApply(aug.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), aug.RandomGrayscale(p=0.2), aug.RandomHorizontalFlip(), RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1), aug.RandomResizedCrop(size=image_size), aug.Normalize( mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]), ), ) |
上述代碼通過(guò)Kornia實(shí)現(xiàn)數(shù)據(jù)轉(zhuǎn)換,這是一個(gè)基于 PyTorch 的可微分的計(jì)算機(jī)視覺(jué)開源庫(kù)。當(dāng)然,你可以用其他開源庫(kù)實(shí)現(xiàn)數(shù)據(jù)轉(zhuǎn)換擴(kuò)充,甚至是自己編寫。實(shí)際上,可微分性對(duì)BYOL而言并沒(méi)有那么必要。
接下來(lái),我們編寫編碼器模塊。該模塊負(fù)責(zé)從基本模型提取特征,并將這些特征投影到低維隱空間。具體的,我們通過(guò)wrapper類實(shí)現(xiàn)該模塊,這樣我們可以輕松將BYOL用于任何模型,無(wú)需將模型編碼到腳本。該類主要由兩部分組成:
特征抽取,獲取模型最后一層的輸出。
映射,非線性層,將輸出映射到更低維空間。
特征提取通過(guò)hooks實(shí)現(xiàn)(如果你不了解hooks,推薦閱讀我之前的介紹文章How to Use PyTorch Hooks)。除此之外,代碼其他部分很容易理解。
from typing import Union def mlp(dim: int, projection_size: int = 256, hidden_size: int = 4096) -> nn.Module: return nn.Sequential( nn.Linear(dim, hidden_size), nn.BatchNorm1d(hidden_size), nn.ReLU(inplace=True), nn.Linear(hidden_size, projection_size), ) class EncoderWrapper(nn.Module): def __init__( self, model: nn.Module, projection_size: int = 256, hidden_size: int = 4096, layer: Union[str, int] = -2, ): super().__init__() self.model = model self.projection_size = projection_size self.hidden_size = hidden_size self.layer = layer self._projector = None self._projector_dim = None self._encoded = torch.empty(0) self._register_hook() @property def projector(self): if self._projector is None: self._projector = mlp( self._projector_dim, self.projection_size, self.hidden_size ) return self._projector def _hook(self, _, __, output): output = output.flatten(start_dim=1) if self._projector_dim is None: self._projector_dim = output.shape[-1] self._encoded = self.projector(output) def _register_hook(self): if isinstance(self.layer, str): layer = dict([*self.model.named_modules()])[self.layer] else: layer = list(self.model.children())[self.layer] layer.register_forward_hook(self._hook) def forward(self, x: Tensor) -> Tensor: _ = self.model(x) return self._encoded |
BYOL包含兩個(gè)相同的編碼器網(wǎng)絡(luò)。第一個(gè)編碼器網(wǎng)絡(luò)的權(quán)重隨著每一訓(xùn)練批次進(jìn)行更新,而第二個(gè)網(wǎng)絡(luò)(稱為“目標(biāo)”網(wǎng)絡(luò))使用第一個(gè)編碼器權(quán)重均值進(jìn)行更新。在訓(xùn)練過(guò)程中,目標(biāo)網(wǎng)絡(luò)接收原始批次訓(xùn)練數(shù)據(jù),而另一個(gè)編碼器則接收相應(yīng)的轉(zhuǎn)換數(shù)據(jù)。兩個(gè)編碼器網(wǎng)絡(luò)會(huì)分別為相應(yīng)數(shù)據(jù)生成低維表示。然后,我們使用多層感知器預(yù)測(cè)目標(biāo)網(wǎng)絡(luò)的輸出,并最大化該預(yù)測(cè)與目標(biāo)網(wǎng)絡(luò)輸出之間的相似性。
圖源:Bootstrap Your Own Latent, Figure 2
也許有人會(huì)想,我們不是應(yīng)該直接比較數(shù)據(jù)轉(zhuǎn)換之前和之后的隱向量表征嗎?為什么還有設(shè)計(jì)多層感知機(jī)?假設(shè)沒(méi)有MLP層的話,網(wǎng)絡(luò)可以通過(guò)將權(quán)重降低到零方便的使所有圖像的表示相似化,可這樣模型并沒(méi)有學(xué)到任何有用的東西,而MLP層可以識(shí)別出數(shù)據(jù)轉(zhuǎn)換并預(yù)測(cè)目標(biāo)隱向量。這樣避免了權(quán)重趨零,可以學(xué)習(xí)更恰當(dāng)?shù)臄?shù)據(jù)表示!
訓(xùn)練結(jié)束后,舍棄目標(biāo)網(wǎng)絡(luò)編碼器,只保留一個(gè)編碼器,根據(jù)該編碼器,所有訓(xùn)練數(shù)據(jù)可生成自洽表示。這正是BYOL能夠進(jìn)行自監(jiān)督學(xué)習(xí)的關(guān)鍵!因?yàn)閷W(xué)習(xí)到的表示具有自洽性,所以經(jīng)不同的數(shù)據(jù)變換后幾乎保持不變。這樣,模型使得相似示例的表示更加接近!
接下來(lái)編寫B(tài)YOL的訓(xùn)練代碼。我選擇使用Pythorch Lightning開源庫(kù),該庫(kù)基于PyTorch,對(duì)深度學(xué)習(xí)項(xiàng)目非常友好,能夠進(jìn)行多GPU培訓(xùn)、實(shí)驗(yàn)日志記錄、模型斷點(diǎn)檢查和混合精度訓(xùn)練等,甚至在cloud TPU上也支持基于該庫(kù)運(yùn)行PyTorch模型!
from copy import deepcopy from itertools import chain from typing import Dict, List import pytorch_lightning as pl from torch import optim import torch.nn.functional as f def normalized_mse(x: Tensor, y: Tensor) -> Tensor: x = f.normalize(x, dim=-1) y = f.normalize(y, dim=-1) return 2 - 2 * (x * y).sum(dim=-1) class BYOL(pl.LightningModule): def __init__( self, model: nn.Module, image_size: Tuple[int, int] = (128, 128), hidden_layer: Union[str, int] = -2, projection_size: int = 256, hidden_size: int = 4096, augment_fn: Callable = None, beta: float = 0.99, **hparams, ): super().__init__() self.augment = default_augmentation(image_size) if augment_fn is None else augment_fn self.beta = beta self.encoder = EncoderWrapper( model, projection_size, hidden_size, layer=hidden_layer ) self.predictor = nn.Linear(projection_size, projection_size, hidden_size) self.hparams = hparams self._target = None self.encoder(torch.zeros(2, 3, *image_size)) def forward(self, x: Tensor) -> Tensor: return self.predictor(self.encoder(x)) @property def target(self): if self._target is None: self._target = deepcopy(self.encoder) return self._target def update_target(self): for p, pt in zip(self.encoder.parameters(), self.target.parameters()): pt.data = self.beta * pt.data + (1 - self.beta) * p.data # --- Methods required for PyTorch Lightning only! --- def configure_optimizers(self): optimizer = getattr(optim, self.hparams.get("optimizer", "Adam")) lr = self.hparams.get("lr", 1e-4) weight_decay = self.hparams.get("weight_decay", 1e-6) return optimizer(self.parameters(), lr=lr, weight_decay=weight_decay) def training_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]: x = batch[0] with torch.no_grad(): x1, x2 = self.augment(x), self.augment(x) pred1, pred2 = self.forward(x1), self.forward(x2) with torch.no_grad(): targ1, targ2 = self.target(x1), self.target(x2) loss = torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1)) self.log("train_loss", loss.item()) return {"loss": loss} @torch.no_grad() def validation_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]: x = batch[0] x1, x2 = self.augment(x), self.augment(x) pred1, pred2 = self.forward(x1), self.forward(x2) targ1, targ2 = self.target(x1), self.target(x2) loss = torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1)) return {"loss": loss} @torch.no_grad() def validation_epoch_end(self, outputs: List[Dict]) -> Dict: val_loss = sum(x["loss"] for x in outputs) / len(outputs) self.log("val_loss", val_loss.item()) |
上述代碼部分源自Pythorch Lightning提供的示例代碼。這段代碼你尤其需要關(guān)注的是training_step,在此函數(shù)實(shí)現(xiàn)模型的數(shù)據(jù)轉(zhuǎn)換、特征投影和相似性損失計(jì)算等。
下文我們將在STL10數(shù)據(jù)集上對(duì)BYOL進(jìn)行實(shí)驗(yàn)驗(yàn)證。因?yàn)樵摂?shù)據(jù)集同時(shí)包含大量未標(biāo)記的圖像以及標(biāo)記的訓(xùn)練和測(cè)試集,非常適合無(wú)監(jiān)督和自監(jiān)督學(xué)習(xí)實(shí)驗(yàn)。STL10網(wǎng)站這樣描述該數(shù)據(jù)集:
STL-10數(shù)據(jù)集是一個(gè)用于研究無(wú)監(jiān)督特征學(xué)習(xí)、深度學(xué)習(xí)、自學(xué)習(xí)算法的圖像識(shí)別數(shù)據(jù)集。該數(shù)據(jù)集是對(duì)CIFAR-10數(shù)據(jù)集的改進(jìn),最明顯的便是,每個(gè)類的標(biāo)記訓(xùn)練數(shù)據(jù)比CIFAR-10中的要少,但在監(jiān)督訓(xùn)練之前,數(shù)據(jù)集提供大量的未標(biāo)記樣本訓(xùn)練模型學(xué)習(xí)圖像模型。因此,該數(shù)據(jù)集主要的挑戰(zhàn)是利用未標(biāo)記的數(shù)據(jù)(與標(biāo)記數(shù)據(jù)相似但分布不同)來(lái)構(gòu)建有用的先驗(yàn)知識(shí)。
通過(guò)Torchvision可以很方便的加載STL10,因此無(wú)需擔(dān)心數(shù)據(jù)的下載和預(yù)處理。
from torchvision.datasets import STL10 from torchvision.transforms import ToTensor TRAIN_DATASET = STL10(root="data", split="train", download=True, transform=ToTensor()) TRAIN_UNLABELED_DATASET = STL10( root="data", split="train+unlabeled", download=True, transform=ToTensor() ) TEST_DATASET = STL10(root="data", split="test", download=True, transform=ToTensor()) |
同時(shí),我們使用監(jiān)督學(xué)習(xí)方法作為基準(zhǔn)模型,以此衡量本文模型的準(zhǔn)確性?;€模型也可通過(guò)Lightning模塊輕易實(shí)現(xiàn):
class SupervisedLightningModule(pl.LightningModule): def __init__(self, model: nn.Module, **hparams): super().__init__() self.model = model def forward(self, x: Tensor) -> Tensor: return self.model(x) def configure_optimizers(self): optimizer = getattr(optim, self.hparams.get("optimizer", "Adam")) lr = self.hparams.get("lr", 1e-4) weight_decay = self.hparams.get("weight_decay", 1e-6) return optimizer(self.parameters(), lr=lr, weight_decay=weight_decay) def training_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]: x, y = batch loss = f.cross_entropy(self.forward(x), y) self.log("train_loss", loss.item()) return {"loss": loss} @torch.no_grad() def validation_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]: x, y = batch loss = f.cross_entropy(self.forward(x), y) return {"loss": loss} @torch.no_grad() def validation_epoch_end(self, outputs: List[Dict]) -> Dict: val_loss = sum(x["loss"] for x in outputs) / len(outputs) self.log("val_loss", val_loss.item()) |
可以看到,使用Pythorch Lightning可以方便的構(gòu)建并訓(xùn)練模型。只需為訓(xùn)練集和測(cè)試集創(chuàng)建DataLoader
對(duì)象,將其導(dǎo)入需要訓(xùn)練的模型即可。本實(shí)驗(yàn)中,epoch設(shè)置為25,學(xué)習(xí)率為1e-4。
from os import cpu_count from torch.utils.data import DataLoader from torchvision.models import resnet18 model = resnet18(pretrained=True) supervised = SupervisedLightningModule(model) trainer = pl.Trainer(max_epochs=25, gpus=-1, weights_summary=None) train_loader = DataLoader( TRAIN_DATASET, batch_size=128, shuffle=True, drop_last=True, ) val_loader = DataLoader( TEST_DATASET, batch_size=128, ) trainer.fit(supervised, train_loader, val_loader) |
接下來(lái),我們使用BYOL對(duì)ResNet18模型進(jìn)行預(yù)訓(xùn)練。在這次實(shí)驗(yàn)中,我選擇epoch為50,學(xué)習(xí)率依然是1e-4。注:該過(guò)程是本文代碼耗時(shí)最長(zhǎng)的部分,在K80 GPU的標(biāo)準(zhǔn)Colab中大約需要45分鐘。
model = resnet18(pretrained=True) byol = BYOL(model, image_size=(96, 96)) trainer = pl.Trainer( max_epochs=50, gpus=-1, accumulate_grad_batches=2048 // 128, weights_summary=None, ) train_loader = DataLoader( TRAIN_UNLABELED_DATASET, batch_size=128, shuffle=True, drop_last=True, ) trainer.fit(byol, train_loader, val_loader) |
然后,我們使用新的ResNet18模型重新進(jìn)行監(jiān)督學(xué)習(xí)。(為徹底清除BYOL中的前向hook,我們實(shí)例化一個(gè)新模型,在該模型引入經(jīng)過(guò)訓(xùn)練的狀態(tài)字典。)
# Extract the state dictionary, initialize a new ResNet18 model, # and load the state dictionary into the new model. # # This ensures that we remove all hooks from the previous model, # which are automatically implemented by BYOL. state_dict = model.state_dict() model = resnet18() model.load_state_dict(state_dict) supervised = SupervisedLightningModule(model) trainer = pl.Trainer( max_epochs=25, gpus=-1, weights_summary=None, ) train_loader = DataLoader( TRAIN_DATASET, batch_size=128, shuffle=True, drop_last=True, ) trainer.fit(supervised, train_loader, val_loader) |
通過(guò)這種方式,模型準(zhǔn)確率提高了約2.5%,達(dá)到了87.7%!雖然該方法需要更多的代碼(大約300行)以及一些庫(kù)的支撐,但相比其他自監(jiān)督方法仍顯得簡(jiǎn)潔。作為對(duì)比,可以看下官方的SimCLR或SwAV是多么復(fù)雜。而且,本文具有更快的訓(xùn)練速度,即使是Colab的免費(fèi)GPU,整個(gè)實(shí)驗(yàn)也不到一個(gè)小時(shí)。
本文要點(diǎn)總結(jié)如下。首先也是最重要的,BYOL是一種巧妙的自監(jiān)督學(xué)習(xí)方法,可以利用未標(biāo)記的數(shù)據(jù)來(lái)最大限度地提高模型性能。此外,由于所有ResNet模型都是使用ImageNet進(jìn)行預(yù)訓(xùn)練的,因此BYOL的性能優(yōu)于預(yù)訓(xùn)練的ResNet18。STL10是ImageNet的一個(gè)子集,所有圖像都從224x224像素縮小到96x96像素。雖然分辨率發(fā)生改變,我們希望自監(jiān)督學(xué)習(xí)能避免這樣的影響,表現(xiàn)出較好性能,而僅僅依靠STL10的小規(guī)模訓(xùn)練集是不夠的。
類似ResNet這樣的模型中,ML從業(yè)人員過(guò)于依賴預(yù)先訓(xùn)練的權(quán)重。雖然這在一定情況下是很好的選擇,但不一定適合其他數(shù)據(jù),哪怕在STL10這樣與ImageNet高度相似的數(shù)據(jù)中表現(xiàn)也不如人意。因此,我迫切希望將來(lái)在深度學(xué)習(xí)的研究中,自監(jiān)督方法能夠獲得更多的關(guān)注與實(shí)踐應(yīng)用。
https://arxiv.org/pdf/2006.07733.pdf
https://arxiv.org/pdf/2006.10029v2.pdf
https://github.com/fkodom/byol
https://github.com/lucidrains/byol-pytorch
https://github.com/google-research/simclr
https://cs.stanford.edu/~acoates/stl10/
AI研習(xí)社是AI學(xué)術(shù)青年和AI開發(fā)者技術(shù)交流的在線社區(qū)。我們與高校、學(xué)術(shù)機(jī)構(gòu)和產(chǎn)業(yè)界合作,通過(guò)提供學(xué)習(xí)、實(shí)戰(zhàn)和求職服務(wù),為AI學(xué)術(shù)青年和開發(fā)者的交流互助和職業(yè)發(fā)展打造一站式平臺(tái),致力成為中國(guó)最大的科技創(chuàng)新人才聚集地。
如果,你也是位熱愛(ài)分享的AI愛(ài)好者。歡迎與譯站一起,學(xué)習(xí)新知,分享成長(zhǎng)。
雷峰網(wǎng)版權(quán)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見(jiàn)轉(zhuǎn)載須知。