0
由于其對于原始數(shù)據(jù)潛在概率分布的強大感知能力,GAN 成為了當(dāng)下最熱門的生成模型之一。然而,訓(xùn)練不穩(wěn)定、調(diào)參難度大一直是困擾著 GAN 愛好者的老問題。本文是一份干貨滿滿的 GAN 訓(xùn)練心得,希望對有志從事該領(lǐng)域研究和工作的讀者有所幫助!
在當(dāng)下的深度學(xué)習(xí)研究領(lǐng)域中,對抗生成網(wǎng)絡(luò)(GAN)是最熱門的話題之一。在過去的幾個月里,關(guān)于 GAN 的論文數(shù)量呈井噴式增長。GAN 已經(jīng)被應(yīng)廣泛應(yīng)用到了各種各樣的問題上,如果你之前對此并不太了解,可以通過下面的 Github 鏈接看到一些酷炫的 GAN 應(yīng)用:
時至今日,我已經(jīng)閱讀了大量有關(guān) GAN 的文獻,但我還從來沒有自己動手實踐過。因此,在瀏覽了一些對人有所啟發(fā)的論文和 Github 代碼倉庫后,我決定親自嘗試訓(xùn)練一個簡單的 GAN。不出所料,我立刻就遇到了一些問題。
本文的目標(biāo)讀者是從 GAN 入門的熱愛深度學(xué)習(xí)的朋友。除非你走了大運,否則你自己第一次訓(xùn)練一個 GAN 的過程可能是非常令人沮喪的,而且需要花費好幾個小時才能做好。當(dāng)然,隨著時間的推移和經(jīng)驗的增長,你可能會漸漸善于訓(xùn)練 GAN。但是對于初學(xué)者來說,可能會犯一些錯,而且不知道該從哪里開始調(diào)試。在本文中,我想向大家分享我第一次從頭開始訓(xùn)練 GAN 時的觀察和經(jīng)驗教訓(xùn),希望本文可以幫助大家節(jié)省幾個小時的調(diào)試時間。
在過去的一年左右的時間里,深度學(xué)習(xí)圈子里的每個人(甚至一些沒有參與過深度學(xué)習(xí)相關(guān)工作的人),都應(yīng)該對 GAN 有所耳聞(除非你住在深山老林里、與世隔絕)。生成對抗網(wǎng)絡(luò)(GAN)是一種數(shù)據(jù)的生成式模型,主要以深度神經(jīng)網(wǎng)絡(luò)的形式存在。也就是說,給定一組訓(xùn)練數(shù)據(jù),GAN 可以學(xué)會估計數(shù)據(jù)的底層概率分布。這一點非常有用,因為我們現(xiàn)在可以根據(jù)學(xué)到的概率分布生成原始訓(xùn)練數(shù)據(jù)集中沒有出現(xiàn)過的樣本。如上面的鏈接所示,這催生了一些非常實用的應(yīng)用程序。
該領(lǐng)域的專家已經(jīng)提供了一些很棒的資源來解釋 GAN 以及它們的工作遠離,所以本文在這里不會重復(fù)他們的工作。但是為了保持文章的完整性,在這里對相關(guān)概念進行簡要的回顧。
GAN 模型概覽
生成對抗網(wǎng)絡(luò)實際上是兩個相互競爭的深度網(wǎng)絡(luò)。給定一個訓(xùn)練集 X(比如說幾千張貓的圖像),生成網(wǎng)絡(luò) G(x) 會將隨機向量作為輸入,并試圖生成與訓(xùn)練集中的圖像相類似的新圖像樣本。判別器網(wǎng)絡(luò) D(x) 則是一種二分類器,試圖將訓(xùn)練集 X 中「真實的」貓的圖像和由生成器生成的「假的」貓圖像區(qū)分開來。如此一來,生成網(wǎng)絡(luò)的職責(zé)就是學(xué)習(xí) X 中的數(shù)據(jù)的分布,這樣它就可以生成看起來真實的貓圖像,并確保判別器無法區(qū)分來自訓(xùn)練集的貓圖像和來自生成器的貓圖像。判別器則需要通過學(xué)習(xí)跟上生成器不斷進化、嘗試通過新的方式生成可以「騙過」判別器的「假的」貓圖像的步伐。
最終,如果一切順利,生成器(或多或少)會學(xué)到訓(xùn)練數(shù)據(jù)的真實分布,并變得非常善于生成看起來真實的貓圖像。而判別器則不能再將訓(xùn)練集中的貓圖像和生成的貓圖像區(qū)分開來。
從這個意義上說,這兩個網(wǎng)絡(luò)一直在努力確保對方不能很好地完成自己的任務(wù)。那么,這究竟是如何起作用的呢?
另一種看待 GAN 的方式是:判別器試圖通過高速生成器真實的貓圖像看起來是怎樣的,從而引導(dǎo)生成器。最終,生成器研究清楚了問題,開始生成看起來真實的貓圖像。訓(xùn)練 GAN 的方法類似于博弈論中的極大極小算法,兩個網(wǎng)絡(luò)試圖達到同時考慮二者的納什均衡。更多細節(jié),請參閱本文底部給出的參考資料。
下面,我們將繼續(xù)分析 GAN 的訓(xùn)練過程。為了簡單起見,我使用了「Keras+Tensorflow 后端」的組合,在 MNIST 數(shù)據(jù)集上訓(xùn)練了一個 GAN(確切地說是 DC-GAN)。這并不太困難,在對生成器和判別器網(wǎng)絡(luò)進行了一些小的調(diào)整之后,GAN 就可以生成清晰的 MNIST 圖像了。
生成的 MNIST 數(shù)字
如果你覺得 MNIST 中黑白數(shù)字沒那么有趣,那么生成各種物體和人的彩色圖片還很酷炫的。而這樣一來,問題就變得棘手了。在攻克了 MNIST 數(shù)據(jù)集之后,顯然下一步就是生成 CIFAR-10 圖像。經(jīng)過日復(fù)一日的超參數(shù)調(diào)參、改變網(wǎng)絡(luò)架構(gòu)、增添或刪除網(wǎng)絡(luò)層,我終于能夠生成出高質(zhì)量的和 CIFAR-10 類似的圖像。
使用 DC-GAN 生成的青蛙
使用 DC-GAN 生成的汽車
我最初使用了一個非常深的網(wǎng)絡(luò)(但是大多數(shù)情況下性能并不佳),最后使用的真正有效的網(wǎng)絡(luò)卻十分簡單。在我開始調(diào)整網(wǎng)絡(luò)和訓(xùn)練過程時,經(jīng)過 15 個 epoch 的訓(xùn)練后生成的圖像從這樣:
變成了這樣:
最終的結(jié)果是:
下面,我基于自己犯過的錯誤以及一直以來學(xué)到的東西,總結(jié)出了 7 大規(guī)避 GAN 訓(xùn)練陷阱的法則。所以,如果你是一個 GAN 新兵,在訓(xùn)練中沒有很多成功的經(jīng)驗,也許看看下面的幾個方面可能會有所幫助:
鄭重聲明:下面我只是列舉出了我嘗試過的事情以及得到的結(jié)果。并且,我并不是說已經(jīng)解決了所有訓(xùn)練 GAN 的問題。
更大的卷積和可以覆蓋前一層特征圖中的更多像素,因此可以關(guān)注到更多的信息。在 CIFAR-10 數(shù)據(jù)集上,5*5 的卷積核可以取得很好的效果,而在判別器中使用 3*3 的卷積核會使判別器損失迅速趨近于 0。對于生成器來說,我們希望在頂層的卷積層中使用較大的卷積核來保持某種平滑性。而在較底層,我并沒有發(fā)現(xiàn)改變卷積核的大小會帶來任何關(guān)鍵的影響。
卷積核的數(shù)量的提升會大幅增加參數(shù)的數(shù)量,但通常我們確實需要更多的卷積核。我?guī)缀踉谒械木矸e層中都使用了 128 個卷積核。特別是在生成器中,使用較少的卷積核會使得最終生成的圖像太模糊。因此,似乎使用更多的卷積核有助于捕獲額外的信息,最終會提升生成圖像的清晰度。
盡管這一開始似乎有些奇怪,但是對我來說,改變標(biāo)簽的分配是一個重要的技巧。
如果你正在使用「真實圖像=1」、「生成圖像=0」的標(biāo)簽分配方法,將標(biāo)簽反轉(zhuǎn)過來會對訓(xùn)練有所幫助。正如我們會在后文中看到的,這有助于在迭代早期梯度流的傳播,也有助于訓(xùn)練的順利進行。
這一點在訓(xùn)練判別器時極為重要。使用硬標(biāo)簽(非 1 即 0)幾乎會在早期就摧毀所有的學(xué)習(xí)進程,導(dǎo)致判別器的損失迅速趨近于 0。我最終用一個 0-0.1 之間的隨機數(shù)來代表「標(biāo)簽 0」(真實圖像),并使用一個 0.9-1 之間的隨機數(shù)來代表 「標(biāo)簽 1」(生成圖像)。在訓(xùn)練生成器時則不用這樣做。
此外,添加一些帶噪聲的標(biāo)簽是有所幫助的。在我的實驗過程中,我將輸入給判別器的圖像中的 5% 的標(biāo)簽隨機進行了反轉(zhuǎn),即真實圖像被標(biāo)記為生成圖像、生成圖像被標(biāo)記為真實圖像。
批量歸一化當(dāng)然對提升最終的結(jié)果有所幫助。加入批量歸一化可以最終生成明顯更清晰的圖像。但是,如果你錯誤地設(shè)置了卷積核的大小和數(shù)量,或者判別器損失迅速趨近于 0,那加入批量歸一化可能也無濟于事。
在網(wǎng)絡(luò)中加入批量歸一化(BN)層后生成的汽車
為了便于訓(xùn)練 GAN,確保輸入數(shù)據(jù)有類似的特性是很有用的。例如,與其在 CIFAR-10 數(shù)據(jù)集中所有 10 個類別上訓(xùn)練 GAN,不如選出一個類別(比如汽車或青蛙),訓(xùn)練 GAN 根據(jù)此類數(shù)據(jù)生成圖像。DCGAN 的另外一些變體可以很好地學(xué)會根據(jù)若干個類生成圖像。例如,條件 GAN(CGAN)將類別標(biāo)簽一同作為輸入,以類別標(biāo)簽為先驗條件生成圖像。但是,如果你從一個基礎(chǔ)的 DCGAN 開始學(xué)習(xí)訓(xùn)練 GAN,最好保持模型簡單。
如果可能的話,請監(jiān)控網(wǎng)絡(luò)中的梯度和損失變化。這可以幫助我們了解訓(xùn)練的進展情況。如果訓(xùn)練進展不是很順利的話,這甚至可以幫助我們進行調(diào)試。
理想情況下,生成器應(yīng)該在訓(xùn)練的早期接受大梯度,因為它需要學(xué)會如何生成看起來真實的數(shù)據(jù)。另一方面,判別器則在訓(xùn)練早期則不應(yīng)該總是接受大梯度,因為它可以很容易地區(qū)分真實圖像和生成圖像。當(dāng)生成器訓(xùn)練地足夠好時,判別器就沒有那么容易區(qū)分真實圖像和生成圖像了。它會不斷發(fā)生錯誤,并得到較大的梯度。
我在 CIFAR-10 中的汽車上訓(xùn)練的幾個早期版本的 GAN 有許多卷積層和批量歸一化層,并且沒有進行標(biāo)簽反轉(zhuǎn)。除了監(jiān)控梯度的變化趨勢,監(jiān)控梯度的大小也很重要。如果生成器中網(wǎng)絡(luò)層的梯度太小,學(xué)習(xí)可能會很慢或者根本不會進行學(xué)習(xí)。
生成器頂層的梯度(x 軸:minibatch 迭代次數(shù))
生成器底層的梯度(x 軸:minibatch 迭代次數(shù))
判別器頂層的梯度(x 軸:minibatch 迭代次數(shù))
判別器底層的梯度(x 軸:minibatch 迭代次數(shù))
生成器最底層的梯度太小,無法進行任何的學(xué)習(xí)。判別器的梯度自始至終都沒有變化,說明判別器并沒有真正學(xué)到任何東西。現(xiàn)在,讓我們將其與帶有上述所有改進方案的 GAN 的梯度進行對比,改進后的 GAN 得到了很好的、與真實圖像看起來類似的圖像:
生成器頂層的梯度(x 軸:minibatch 迭代次數(shù))
生成器底層的梯度(x 軸:minibatch 迭代次數(shù))
判別器頂層的梯度(x 軸:minibatch 迭代次數(shù))
判別器底層的梯度(x 軸:minibatch 迭代次數(shù))
此時生成器底層的梯度明顯要高于之前版本的 GAN。此外,隨著訓(xùn)練的進展,梯度流的變化趨勢與預(yù)期一樣:生成器在訓(xùn)練早期梯度較大,而一旦生成器被訓(xùn)練得足夠好,判別器的頂層就會維持高的梯度。
可能是由于我缺乏耐心,我犯了一個愚蠢的錯誤——在進行了幾百個 minibatch 的訓(xùn)練后,當(dāng)我看到損失函數(shù)仍然沒有任何明顯的下降,生成的樣本仍然充滿噪聲時,我終止了訓(xùn)練。比起等到訓(xùn)練結(jié)束才意識到網(wǎng)絡(luò)什么都沒有學(xué)到,重新開始工作、節(jié)省時間確實讓人心動。GAN 的訓(xùn)練時間很長,初始的少量的損失值和生成的樣本幾乎不能顯示出任何趨勢和進展。在結(jié)束訓(xùn)練過程并調(diào)整設(shè)置之前,還是很有必要等待一段時間的。
這條規(guī)則的一個例外情況是:如果你看到判別器損失迅速趨近于 0。如果發(fā)生了這種情況,幾乎就沒有任何機會補救了。最好在對網(wǎng)絡(luò)或訓(xùn)練過程進行調(diào)整后重新開始訓(xùn)練。
最終的 GAN 的架構(gòu)如下所示:
希望本文中的這些建議可以幫助所有人從頭開始訓(xùn)練他們的第一個 DC-GAN。下面,本文將給出一些包含大量關(guān)于 GAN 的信息的學(xué)習(xí)資源:
GAN 論文參考:
「Generative Adversarial Networks」
「Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks」
「Improved Techniques for Training GANs」
其他參考鏈接:
「Training GANs: Better understanding and other improved techniques」
「NIPS 2016 GAN 教程」
「Conditional GAN」
本文最終版 GAN 的 Keras 代碼鏈接如下:
https://github.com/utkd/gans/blob/master/cifar10dcgan.ipynb?source=post_page
via https://medium.com/@utk.is.here/keep-calm-and-train-a-gan-pitfalls-and-tips-on-training-generative-adversarial-networks-edd529764aa9 雷鋒網(wǎng)雷鋒網(wǎng)雷鋒網(wǎng)
雷峰網(wǎng)原創(chuàng)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。