0
本文作者: AI研習(xí)社 | 2017-07-10 11:33 |
雷鋒網(wǎng)按:本文作者廖星宇,原載于作者知乎專(zhuān)欄,雷鋒網(wǎng)經(jīng)授權(quán)發(fā)布。
自動(dòng)編碼器(AutoEncoder)最開(kāi)始作為一種數(shù)據(jù)的壓縮方法,其特點(diǎn)有:
跟數(shù)據(jù)相關(guān)程度很高,這意味著自動(dòng)編碼器只能壓縮與訓(xùn)練數(shù)據(jù)相似的數(shù)據(jù),這個(gè)其實(shí)比較顯然,因?yàn)槭褂蒙窠?jīng)網(wǎng)絡(luò)提取的特征一般是高度相關(guān)于原始的訓(xùn)練集,使用人臉訓(xùn)練出來(lái)的自動(dòng)編碼器在壓縮自然界動(dòng)物的圖片是表現(xiàn)就會(huì)比較差,因?yàn)樗粚W(xué)習(xí)到了人臉的特征,而沒(méi)有能夠?qū)W習(xí)到自然界圖片的特征;
壓縮后數(shù)據(jù)是有損的,這是因?yàn)樵诮稻S的過(guò)程中不可避免的要丟失掉信息;
到了2012年,人們發(fā)現(xiàn)在卷積網(wǎng)絡(luò)中使用自動(dòng)編碼器做逐層預(yù)訓(xùn)練可以訓(xùn)練更加深層的網(wǎng)絡(luò),但是很快人們發(fā)現(xiàn)良好的初始化策略要比費(fèi)勁的逐層預(yù)訓(xùn)練有效地多,2014年出現(xiàn)的Batch Normalization技術(shù)也是的更深的網(wǎng)絡(luò)能夠被被有效訓(xùn)練,到了15年底,通過(guò)殘差(ResNet)我們基本可以訓(xùn)練任意深度的神經(jīng)網(wǎng)絡(luò)。
所以現(xiàn)在自動(dòng)編碼器主要應(yīng)用有兩個(gè)方面,第一是數(shù)據(jù)去噪,第二是進(jìn)行可視化降維。然而自動(dòng)編碼器還有著一個(gè)功能就是生成數(shù)據(jù)。
我們之前講過(guò)GAN,它與GAN相比有著一些好處,同時(shí)也有著一些缺點(diǎn)。我們先來(lái)講講其跟GAN相比有著哪些優(yōu)點(diǎn)。
第一點(diǎn),我們使用GAN來(lái)生成圖片有個(gè)很不好的缺點(diǎn)就是我們生成圖片使用的隨機(jī)高斯噪聲,這意味著我們并不能生成任意我們指定類(lèi)型的圖片,也就是說(shuō)我們沒(méi)辦法決定使用哪種隨機(jī)噪聲能夠產(chǎn)生我們想要的圖片,除非我們能夠把初始分布全部試一遍。但是使用自動(dòng)編碼器我們就能夠通過(guò)輸出圖片的編碼過(guò)程得到這種類(lèi)型圖片的編碼之后的分布,相當(dāng)于我們是知道每種圖片對(duì)應(yīng)的噪聲分布,我們就能夠通過(guò)選擇特定的噪聲來(lái)生成我們想要生成的圖片。
第二點(diǎn),這既是生成網(wǎng)絡(luò)的優(yōu)點(diǎn)同時(shí)又有著一定的局限性,這就是生成網(wǎng)絡(luò)通過(guò)對(duì)抗過(guò)程來(lái)區(qū)分“真”的圖片和“假”的圖片,然而這樣得到的圖片只是盡可能像真的,但是這并不能保證圖片的內(nèi)容是我們想要的,換句話(huà)說(shuō),有可能生成網(wǎng)絡(luò)盡可能的去生成一些背景圖案使得其盡可能真,但是里面沒(méi)有實(shí)際的物體。
首先我們給出自動(dòng)編碼器的一般結(jié)構(gòu)
從上面的圖中,我們能夠看到兩個(gè)部分,第一個(gè)部分是編碼器(Encoder),第二個(gè)部分是解碼器(Decoder),編碼器和解碼器都可以是任意的模型,通常我們使用神經(jīng)網(wǎng)絡(luò)模型作為編碼器和解碼器。輸入的數(shù)據(jù)經(jīng)過(guò)神經(jīng)網(wǎng)絡(luò)降維到一個(gè)編碼(code),接著又通過(guò)另外一個(gè)神經(jīng)網(wǎng)絡(luò)去解碼得到一個(gè)與輸入原數(shù)據(jù)一模一樣的生成數(shù)據(jù),然后通過(guò)去比較這兩個(gè)數(shù)據(jù),最小化他們之間的差異來(lái)訓(xùn)練這個(gè)網(wǎng)絡(luò)中編碼器和解碼器的參數(shù)。當(dāng)這個(gè)過(guò)程訓(xùn)練完之后,我們可以拿出這個(gè)解碼器,隨機(jī)傳入一個(gè)編碼(code),希望通過(guò)解碼器能夠生成一個(gè)和原數(shù)據(jù)差不多的數(shù)據(jù),上面這種圖這個(gè)例子就是希望能夠生成一張差不多的圖片。
這件事情能不能實(shí)現(xiàn)呢?其實(shí)是可以的,下面我們會(huì)用PyTorch來(lái)簡(jiǎn)單的實(shí)現(xiàn)一個(gè)自動(dòng)編碼器。
首先我們構(gòu)建一個(gè)簡(jiǎn)單的多層感知器來(lái)實(shí)現(xiàn)一下。
class autoencoder(nn.Module):
def __init__(self):
super(autoencoder, self).__init__()
self.encoder = nn.Sequential(
nn.Linear(28*28, 128),
nn.ReLU(True),
nn.Linear(128, 64),
nn.ReLU(True),
nn.Linear(64, 12),
nn.ReLU(True),
nn.Linear(12, 3)
)
self.decoder = nn.Sequential(
nn.Linear(3, 12),
nn.ReLU(True),
nn.Linear(12, 64),
nn.ReLU(True),
nn.Linear(64, 128),
nn.ReLU(True),
nn.Linear(128, 28*28),
nn.Tanh()
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
這里我們定義了一個(gè)簡(jiǎn)單的4層網(wǎng)絡(luò)作為編碼器,中間使用ReLU激活函數(shù),最后輸出的維度是3維的,定義的解碼器,輸入三維的編碼,輸出一個(gè)28x28的圖像數(shù)據(jù),特別要注意最后使用的激活函數(shù)是Tanh,這個(gè)激活函數(shù)能夠?qū)⒆詈蟮妮敵鲛D(zhuǎn)換到-1 ~1之間,這是因?yàn)槲覀冚斎氲膱D片已經(jīng)變換到了-1~1之間了,這里的輸出必須和其對(duì)應(yīng)。
訓(xùn)練過(guò)程也比較簡(jiǎn)單,我們使用最小均方誤差來(lái)作為損失函數(shù),比較生成的圖片與原始圖片的每個(gè)像素點(diǎn)的差異。
同時(shí)我們也可以將多層感知器換成卷積神經(jīng)網(wǎng)絡(luò),這樣對(duì)圖片的特征提取有著更好的效果。
class autoencoder(nn.Module):
def __init__(self):
super(autoencoder, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(1, 16, 3, stride=3, padding=1), # b, 16, 10, 10
nn.ReLU(True),
nn.MaxPool2d(2, stride=2), # b, 16, 5, 5
nn.Conv2d(16, 8, 3, stride=2, padding=1), # b, 8, 3, 3
nn.ReLU(True),
nn.MaxPool2d(2, stride=1) # b, 8, 2, 2
)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(8, 16, 3, stride=2), # b, 16, 5, 5
nn.ReLU(True),
nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1), # b, 8, 15, 15
nn.ReLU(True),
nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1), # b, 1, 28, 28
nn.Tanh()
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
這里使用了 nn.ConvTranspose2d(),這可以看作是卷積的反操作,可以在某種意義上看作是反卷積。
我們使用卷積網(wǎng)絡(luò)得到的最后生成的圖片效果會(huì)更好,具體的圖片效果我就不再這里放了,可以在我們的github上看到圖片的展示。github 地址:
變分編碼器是自動(dòng)編碼器的升級(jí)版本,其結(jié)構(gòu)跟自動(dòng)編碼器是類(lèi)似的,也由編碼器和解碼器構(gòu)成。
回憶一下我們?cè)谧詣?dòng)編碼器中所做的事,我們需要輸入一張圖片,然后將一張圖片編碼之后得到一個(gè)隱含向量,這比我們隨機(jī)取一個(gè)隨機(jī)噪聲更好,因?yàn)檫@包含著原圖片的信息,然后我們隱含向量解碼得到與原圖片對(duì)應(yīng)的照片。
但是這樣我們其實(shí)并不能任意生成圖片,因?yàn)槲覀儧](méi)有辦法自己去構(gòu)造隱藏向量,我們需要通過(guò)一張圖片輸入編碼我們才知道得到的隱含向量是什么,這時(shí)我們就可以通過(guò)變分自動(dòng)編碼器來(lái)解決這個(gè)問(wèn)題。
其實(shí)原理特別簡(jiǎn)單,只需要在編碼過(guò)程給它增加一些限制,迫使其生成的隱含向量能夠粗略的遵循一個(gè)標(biāo)準(zhǔn)正態(tài)分布,這就是其與一般的自動(dòng)編碼器最大的不同。
這樣我們生成一張新圖片就很簡(jiǎn)單了,我們只需要給它一個(gè)標(biāo)準(zhǔn)正態(tài)分布的隨機(jī)隱含向量,這樣通過(guò)解碼器就能夠生成我們想要的圖片,而不需要給它一張?jiān)紙D片先編碼。
在實(shí)際情況中,我們需要在模型的準(zhǔn)確率上與隱含向量服從標(biāo)準(zhǔn)正態(tài)分布之間做一個(gè)權(quán)衡,所謂模型的準(zhǔn)確率就是指解碼器生成的圖片與原圖片的相似程度。我們可以讓網(wǎng)絡(luò)自己來(lái)做這個(gè)決定,非常簡(jiǎn)單,我們只需要將這兩者都做一個(gè)loss,然后在將他們求和作為總的loss,這樣網(wǎng)絡(luò)就能夠自己選擇如何才能夠使得這個(gè)總的loss下降。另外我們要衡量?jī)煞N分布的相似程度,如何看過(guò)之前一片GAN的數(shù)學(xué)推導(dǎo),你就知道會(huì)有一個(gè)東西叫KL divergence來(lái)衡量?jī)煞N分布的相似程度,這里我們就是用KL divergence來(lái)表示隱含向量與標(biāo)準(zhǔn)正態(tài)分布之間差異的loss,另外一個(gè)loss仍然使用生成圖片與原圖片的均方誤差來(lái)表示。
我們可以給出KL divergence 的公式
這里變分編碼器使用了一個(gè)技巧“重新參數(shù)化”來(lái)解決 KL divergence 的計(jì)算問(wèn)題。
這時(shí)不再是每次產(chǎn)生一個(gè)隱含向量,而是生成兩個(gè)向量,一個(gè)表示均值,一個(gè)表示標(biāo)準(zhǔn)差,然后通過(guò)這兩個(gè)統(tǒng)計(jì)量來(lái)合成隱含向量,這也非常簡(jiǎn)單,用一個(gè)標(biāo)準(zhǔn)正態(tài)分布先乘上標(biāo)準(zhǔn)差再加上均值就行了,這里我們默認(rèn)編碼之后的隱含向量是服從一個(gè)正態(tài)分布的。這個(gè)時(shí)候我們是想讓均值盡可能接近0,標(biāo)準(zhǔn)差盡可能接近1。而論文里面有詳細(xì)的推導(dǎo)如何得到這個(gè)loss的計(jì)算公式,有興趣的同學(xué)可以去看看具體推到過(guò)程:
https://arxiv.org/pdf/1606.05908.pdf
下面是PyTorch的實(shí)現(xiàn):
reconstruction_function = nn.BCELoss(size_average=False) # mse loss
def loss_function(recon_x, x, mu, logvar):
"""
recon_x: generating images
x: origin images
mu: latent mean
logvar: latent log variance
"""
BCE = reconstruction_function(recon_x, x)
# loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
KLD = torch.sum(KLD_element).mul_(-0.5)
# KL divergence
return BCE + KLD
另外變分編碼器除了可以讓我們隨機(jī)生成隱含變量,還能夠提高網(wǎng)絡(luò)的泛化能力。
最后是VAE的代碼實(shí)現(xiàn):
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
self.fc1 = nn.Linear(784, 400)
self.fc21 = nn.Linear(400, 20)
self.fc22 = nn.Linear(400, 20)
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784)
def encode(self, x):
h1 = F.relu(self.fc1(x))
return self.fc21(h1), self.fc22(h1)
def reparametrize(self, mu, logvar):
std = logvar.mul(0.5).exp_()
if torch.cuda.is_available():
eps = torch.cuda.FloatTensor(std.size()).normal_()
else:
eps = torch.FloatTensor(std.size()).normal_()
eps = Variable(eps)
return eps.mul(std).add_(mu)
def decode(self, z):
h3 = F.relu(self.fc3(z))
return F.sigmoid(self.fc4(h3))
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparametrize(mu, logvar)
return self.decode(z), mu, logvar
VAE的結(jié)果比普通的自動(dòng)編碼器要好很多,下面是結(jié)果:
VAE的缺點(diǎn)也很明顯,他是直接計(jì)算生成圖片和原始圖片的均方誤差而不是像GAN那樣去對(duì)抗來(lái)學(xué)習(xí),這就使得生成的圖片會(huì)有點(diǎn)模糊?,F(xiàn)在已經(jīng)有一些工作是將VAE和GAN結(jié)合起來(lái),使用VAE的結(jié)構(gòu),但是使用對(duì)抗網(wǎng)絡(luò)來(lái)進(jìn)行訓(xùn)練,具體可以參考一下這篇論文:
https://arxiv.org/pdf/1512.09300.pdf
文中相關(guān)代碼鏈接:
英文參考:
雷鋒網(wǎng)相關(guān)閱讀:
深度學(xué)習(xí)全網(wǎng)最全學(xué)習(xí)資料匯總之模型介紹篇
Yann LeCun最新研究成果:可以幫助GAN使用離散數(shù)據(jù)的ARAE
雷峰網(wǎng)原創(chuàng)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見(jiàn)轉(zhuǎn)載須知。