0
雷鋒網(wǎng)AI研習社按:神經(jīng)網(wǎng)絡長久以來的“黑盒”屬性,導致人們一直無法理解網(wǎng)絡的內部是如何運作的,針對這個困擾已久的問題,學界主要存在三種研究方向:數(shù)據(jù)歸因模式、特征解碼模式以及模型理解模式。
在近日的 AI 研習社大講堂上,清華大學的王宇龍就從模型理解的角度入手,為我們詳細介紹了如何通過發(fā)現(xiàn)網(wǎng)絡中的關鍵數(shù)據(jù)通路(critical data routing paths, CDRPs),更好地理解網(wǎng)絡。
公開課回放地址:
分享主題:利用關鍵數(shù)據(jù)通路理解神經(jīng)網(wǎng)絡
分享提綱:
方法介紹——distillation guided routing(DGR)算法
結果分析——路由的通路包含一定語義含義,幫助我們更好地理解網(wǎng)絡行為
應用領域——安全對抗樣本檢測
雷鋒網(wǎng) AI 研習社將其分享內容整理如下:
這張圖大概總結了一下當前這種網(wǎng)絡可解釋性的含義,我們常常把這種網(wǎng)絡、深度神經(jīng)網(wǎng)絡看作一個black box model,也就是一個黑盒模型,就像圖片中所描述的一樣。
我們想知道網(wǎng)絡里面運行的時候到底在做什么,學習完之后究竟學習到什么樣的知識,能夠對我們人類有什么樣的啟發(fā),所以對于網(wǎng)絡可理解性來說,目前有這樣三個主要的研究方向:
左邊綠色箭頭指向的,是說我將網(wǎng)絡所做的決策,或者說預測結果,直接歸因到數(shù)據(jù)層面,去分析樣本中哪些數(shù)據(jù)、或者說數(shù)據(jù)當中哪一塊區(qū)域,它的特征更加重要,我就直接拿來作為網(wǎng)絡行為的一種解釋。
右邊藍色箭頭指向的則是第二種方向也就是將網(wǎng)絡所學到的這種行為或者特征和人類已有的這種像知識圖譜的一些概念去進行聯(lián)系,使網(wǎng)絡在進行決策的時候,我們人類能夠理解它究竟在做些什么。
下面又是另外一種方向,我畫了一個顯微鏡的圖案來表示,我們要直接去探求網(wǎng)絡內部究竟在做什么,這個相對來說會比較難,因為它是一種更加直接的理解網(wǎng)絡的方法——我們的工作就屬于這個方向。
我再細說一下這三個方向。
第一類方法是歸因到數(shù)據(jù)層面的方法,我們叫做attribution methods,中文為“溯因方法”,它的一個大致的流程是說,給一張圖片,或者給一個數(shù)據(jù),網(wǎng)絡幫我們做出了這個決策,那么現(xiàn)在我要了解,究竟是數(shù)據(jù)的哪一面最終導致這個預測結果,然后我們就可以通過attribution methods去追溯和歸因到數(shù)據(jù)層面上去。像右上方第一張圖顯示的,就是幾種不同attribution methods的一個展示結果,主要通過一種叫做saliency maps(中文稱作“顯著圖”)的方式來進行展示,圖片上面可以看到像某種熱力圖一樣,顏色越深的,代表這一部分區(qū)域所做的貢獻越大,對最終預測結果起著最關鍵的影響。
第二類方法是歸因到知識上去,這種被稱作feature decoding的方式(中文“特征解碼”),它將網(wǎng)絡中間的中間層的特征給解碼出來,然后轉換到一個可理解的概念上去。換句話說,我在預測結果的同時,同時產生一段文本去解釋這個預測結果產生的原因。比如說右面的第二張圖,它在預測每一種鳥類的同時,也給出了這些鳥類局部的一些特征,比如說這個鳥之所以是某種鳥類,是因為它的脖子或者頭部或者嘴巴具備什么樣的特征,是一個比較接近人類的這樣一種解釋方式。不過,這種解釋方法回避了模型本身的解釋,也就是說產生這種文本的解釋是用另外一種網(wǎng)絡來去訓練的,那么產生解釋的網(wǎng)絡又該如何去理解呢?
第三類方法是直接去理解模型本身的行為。這一類方法目前沒有一個統(tǒng)一的范式,主要靠大家從多種不同的角度來做解釋。比如說我本來有一個很深的網(wǎng)絡,或者說一個很復雜的模型,那么我通過像知識蒸餾或者說模仿這種行為,去訓練一個更容易理解的模型,通過這個更容易理解的模型,針對某個局部面再去模仿原來模型的行為,是一種基于局部的一種解釋。或者說從一開始設計這種網(wǎng)絡的時候,就是要設計成一個可解釋的模型,比如說每一步都分別對應一種語義含義行為的設計方法。
關鍵數(shù)據(jù)通路
我們的工作主要是通過發(fā)現(xiàn)網(wǎng)絡中的關鍵數(shù)據(jù)通路(critical data routing paths,CDRPs),更好地去理解網(wǎng)絡。我們從之前類似網(wǎng)絡壓縮的一些工作中發(fā)現(xiàn),網(wǎng)絡中其實存在很多冗余,并不是所有的節(jié)點或者說神經(jīng)元都被利用到,這些神經(jīng)元即便刪去,也不會影響最終的預測結果。因此我們認為,每個樣本進來的時候,網(wǎng)絡都只是利用了其中一部分的節(jié)點或者說通路來完成最終預測的。
我們的工作希望能夠發(fā)掘出這些通路的特征或者說規(guī)律。在定義關鍵數(shù)據(jù)通路以前,我們首先要定義的是關鍵節(jié)點,因為關鍵數(shù)據(jù)通路實際上是由這些通路上的關鍵節(jié)點所組成的。比如像右面這張圖顯示的,共有三個卷積層,每一層紅色代表的都是不重要的可以進行刪除的節(jié)點,而綠色則代表的關鍵節(jié)點。通過連接每一層的關鍵節(jié)點,我們也就組成了所謂的關鍵數(shù)據(jù)通路。
所謂的關鍵節(jié)點,我們可以理解成輸出的關鍵channel,比如像一個卷積層,它的輸出是個三維的增量,在長寬兩維上我們認為是一種空間上的信息保留,而channel維就是一個第三維的信息保留,我們認為它是包含了這種語義含義的,或者說是代表這種節(jié)點的概念。如果一個通道全部被置為零,對最終預測結果產生了很大影響的話,這個節(jié)點就是關鍵節(jié)點。
從方法上來說,我們首先引入了control gates(中文為控制門)的概念,這個想法受到過去的模型壓縮、模型減枝方法的一些啟發(fā),channel-wise control gates說的是通道維上的每一維或者每一個通道都去關聯(lián)一個 lambda(這個lambda是一個標量值,我們認為這個lambda就是一個control gates),一旦這個 lambda 為 0,最后的優(yōu)化問題也同樣顯示為 0 的話,就默認它是一個不重要的節(jié)點,完全可以刪除掉。如果它是一個帶 0 的值,之后的值我們是不做限制的,因為值的大小代表了它在這一次預測中的重要性。
我們應該如何求解 lambda 的優(yōu)化目標是什么呢?其實我們是借鑒了這個所謂“知識蒸餾”的概念(Hinton在2015年所提出),也就是說網(wǎng)絡在進行預測的時候,它所輸出的概率分布不只是包含預測結果,還包括了隱藏知識——網(wǎng)絡里被認為包含其他類的概念有多少。我們的目標是要在刪掉這些不重要的節(jié)點以后,網(wǎng)絡的數(shù)據(jù)概率值能夠盡量接近原始的網(wǎng)絡數(shù)據(jù)概率值。
同一時間,我們要加上一定的正則項去約束control gate:一方面約束它大于0,非負數(shù);另一方面則要約束它具備“稀疏”的特性。右邊的優(yōu)化目標表示的是第一部分花體 L 的 loss function,度量的是兩次網(wǎng)絡輸出的概率分布距離。
第一項這個f_\theta(x),是原始網(wǎng)絡在獲得樣本后的一個輸出概率分布,是針對單樣本來說的(每次只考慮一個樣本)。第二項加了一個帶有 lambda 的式子,代表的是引入了control gate后,“減枝”網(wǎng)絡的輸出概率分布和原始網(wǎng)絡的輸出概率分布的接近程度——L。L 實際上就是一個 cross entropy,衡量兩個概率分布的距離。
接下來說一下lambda的約束限制。
第一項是要求它“非負”。lambda可以為 0,代表的是被完全抑制掉、刪除掉(假設該值大于0,小于1,說明比原來的響應相對要小一點,要是大于 0 且大于 1,說明它比原來的響應要更高一點,有一點放大的作用),但是我們不能讓它變?yōu)樨摂?shù),因為負數(shù)相當于把整個 channel 的 activations 的全部正負號都交換了,相當于所有的值都取了一個相反數(shù),我們認為這樣對原始網(wǎng)絡輸出值的分布范圍會有較大影響,且會對最終行為存在較大干擾。所以我們做網(wǎng)絡可解釋性的首要條件,是在保證盡量少改動的情況下去解釋當前網(wǎng)絡,一旦引入過多額外的干擾,你就很難保證說現(xiàn)在的解釋對于原來的網(wǎng)絡還是成立的。
第二項是要求它具有一定的稀疏性,這個和已有的一些“稀疏學習”的部分主張是吻合的,可以理解為越稀疏的模型,它將這種不同的屬性都進行解耦并取了關鍵屬性,就越發(fā)具備可理解性。
路徑的表示
我們以上所說的是 distillation guided routing(DGR)的一個大致方法。接下來我再說一下,如何對最終尋找到的路徑進行表示。
每次優(yōu)化完以后,每層它都有一個contol gate value,由粗體的lambda表示(大K表示的是網(wǎng)絡擁有K層這樣一個概念),只要將所有的control gate value拼接成一個最終長的向量,就是我們對相關路徑的一個表示。因為我們可以直接對長的向量使用tresholding這一種取閾值的方法,來獲得最終的critical nodes。比如說我認為大于0.5的才是真正的critical nodes,小于0.5的則不是,那我們可以通過取義值,得到一個最終的二值mask,那么它就代表了哪些可以被刪除,哪些可以被保留。
我們在后來的實驗中發(fā)現(xiàn),這一種表示包含了非常豐富的信息——如果不取義值,只將它原始優(yōu)化出來的浮點值保留下來的話,網(wǎng)絡在進行預測的時候,我們將發(fā)現(xiàn)更加豐富的功能性過程(可以把它視為一種新的activations,網(wǎng)絡響應都是一層一層傳到最高層,最高層的feature就可以看成一個響應,我們相當于側面在channel維上去引入了新的特征表示)。
接下來我詳細說一下這一頁(PPT第4頁)的優(yōu)化問題,我們應該如何進行求解。
求解的方法其實很簡單,就是通過梯度下降算法,每一次根據(jù)優(yōu)化目標對control gate value進行求導(原始網(wǎng)絡的權重值都是固定不變的)。所以我們解釋一些已有模型(比如像VGG,Alexnet, ResNet),都是通過引入并求解control gate value,接下來當我們再去解釋或者優(yōu)化時就會非常簡單,因為它需要更新的參數(shù)非常少,比如我們在實驗中只需設置30個iteration,就能得到一個很好的解釋結果。
在優(yōu)化的過程當中,這些引入control gate value的網(wǎng)絡預測,比如說top-1 prediction,也就是那個最大類別的響應,要和原始網(wǎng)絡的預測保持一致。比如說原始網(wǎng)絡它看圖片預測出來是狗,那么新的網(wǎng)絡也要保障它的預測結果是狗。至于其他類別的響應,我們則不做要求,因為既然是distillation,肯定就會存在一定程度的不同??偟膩碚f,你在解釋的網(wǎng)絡的時候,不該改變網(wǎng)絡的原始行為。
接下來說一下對抗樣本檢測,我們之所以會將該方法用到這個任務上去,是因為我們發(fā)現(xiàn),我們所找到的這個feature對于對抗樣本檢測有很大的幫助。
首先什么叫對抗樣本?非常簡單,看下面這張圖,第一個是大熊貓,它被輸入進一個標準的網(wǎng)絡里面,被顯示為55.7%的一個預測信度,但是我在中間加了這個噪聲圖,最后得到一張新的圖片,再把這張新的圖片輸入到網(wǎng)絡里時,結果預測為“長臂猿”,同時擁有很高的信度,達到99.3%。從人的視角來看,新生成的圖片跟原始圖片并沒有太大差別,這種現(xiàn)象我們就叫做對抗樣本,也就是說新的圖片對網(wǎng)絡而言是具有“對抗性”或者說“攻擊性”的。
對抗樣本現(xiàn)象引發(fā)了人們對網(wǎng)絡可理解性的關注,因為網(wǎng)絡的“黑盒”特性使我們無從得知它為什么會預測正確或者預測錯誤,而且這種錯誤的特性還特別不符合人類的直覺,人類無法理解說這樣一個噪聲為何能夠引起這么大的一個改變。因此現(xiàn)在有大量的工作就是在做對抗的樣本攻擊以及對抗的樣本防御。我們的組在這方面之前也是做了很多工作,在去年的NIPS 2017年有一個對抗攻防比賽,我們的組在攻擊和對抗方面都做到了第一。
我們接下來會利用關鍵數(shù)據(jù)通路去進行對抗樣本檢測。我們的思考是這樣的,兩種樣本在輸入端從人類的感覺上看來差別并不大,這也意味著前幾層所走的網(wǎng)絡關鍵路徑按理來說差別不大。只是對抗樣本的噪聲越往高層走,它被干擾的程度不知因何被放大了,才導致路徑開始偏離,最終走到另一個類別上去,導致預測結果完全不一樣。
那么我們其實可以訓練出某種分類器,專門用來檢測真實樣本與對抗樣本的關鍵數(shù)據(jù)通路。如果查出來差別,就有一定的概率檢測出它究竟是真實樣本還是對抗樣本。
接下來說一下實驗的部分。
我們首先做了一個定量實驗來檢驗方法的有效性,這個實驗叫做post-hoc interpretation(中文是“事后解釋”),就是針對網(wǎng)絡最終的預測結果再做一次解釋(一張圖片只解釋一個)。在實驗中,數(shù)據(jù)集采用來自 ImagNet 的五萬張 validation images,訓練網(wǎng)絡則用的AlexNet、VGG-16、ResNet-50等。
需要說明的是,實驗只聚焦在卷積層,因為類似 VGG-16、ResNet 的 fully-connected layers,我們認為是一個最終的分類器,所以不考慮這一層面的關鍵數(shù)據(jù)通路。再者,ResNet 的網(wǎng)絡層較深,我們也不可能將所有的卷積層都考慮進來,太冗余且沒有必要。所以對于 ResNet,我們的處理方法就是只關注 ResBlocks 的輸出,而這個 Block 的量相對較少,我們再根據(jù)這些 Block 的輸出去觀察它所利用到的關鍵節(jié)點。
給大家介紹這個實驗,當我們找到關鍵節(jié)點以后,我們將有序地抑制掉一部分的關鍵節(jié)點,然后再觀察它對網(wǎng)絡最終造成多大程度的影響。
在操作上有兩種方式,一種是先刪除control gate value最大的,我們稱作Top Mode,或者反過來,我們先刪除control gate value最小的,這兩種刪除方式最后引起網(wǎng)絡性能下降的一個曲線,在下面這兩張圖上展示。(注:control gate value越大,那么說明它的影響/重要性越大)
可以看下上邊這張圖,橫坐標顯示的是被抑制的關鍵節(jié)點比例,我們可以看到,只有1%的關鍵節(jié)點被抑制(通道置為0),原模型的top-1 acc還有top-5 acc就會面臨非常劇烈的下降,分別是top-1 acc下降百分之三十多,top-5 acc下降百分之二十多。
也就是說,只要1%的關鍵節(jié)點,還不是所有節(jié)點(關鍵節(jié)點其實只占網(wǎng)絡節(jié)點的百分之十左右)被刪除的話,網(wǎng)絡性能就會面臨劇烈的下降。在某種程度上來說,這個結果證明了我們所尋找到的關鍵節(jié)點的有效性。
節(jié)點的語義含義
其實我們更重要的工作成果在這一部分,那就是我們所尋找的節(jié)點其實包含了一定的語義含義,這是網(wǎng)絡可解釋性領域一直在關注的。首先我們會關注層內的路由節(jié)點的語義含義,比如說一個樣本進來,它經(jīng)過每一層,我們會先看每層有哪些節(jié)點,然后再看它擁有什么樣的語義含義。
我們在上方展示了5張圖,每張圖上有五萬個點,對應的是五萬張圖片。不過我們都知道,網(wǎng)絡里的channel維都是像512、256這樣一個向量,我們怎么樣可以把這五萬個向量之間的相似性更直觀的展示出來呢?我們最終采用的是t-SNE方法,類似于說將一些向量投影到二維平面上去。投影的結果就像下面5張圖展示的,顏色代表類別,同樣類別的圖片所對應的點,顏色都是相同的。我們會看到,隨著層數(shù)加深,它的點也隨著變得更加稀疏起來,然而實際上點的數(shù)量是沒有改變的,依然還是五萬張,五萬個點。
為什么會呈現(xiàn)這樣一個稀疏或者分離的現(xiàn)象呢?因為同個類別的點都聚集在同一處,距離也就變得更加靠近,所以看起來中間有很多空白的部分。這也說明,在越高層的地方,同個類別所走過的節(jié)點或路徑會越加相似,簡單來說就是貓走貓的路徑,狗走狗的節(jié)點。
這張圖全面地展示了VGG-16里13個卷積層每一層關鍵節(jié)點的二維圖,我們會看到,在底層里各個類別都混雜在一起,沒有特別明顯的區(qū)分,然而隨著層數(shù)變高,顏色會開始有規(guī)律地聚集到同一個區(qū)域,說明這些類別開始各走各自的路徑。越到高層越稀疏。
在知曉每層節(jié)點的語義情況下,我們想進一步了解由這些節(jié)點連接構成的關鍵數(shù)據(jù)路徑,究竟具備什么樣的語義特征。于是我們做了一個實驗,針對類內樣本(樣本都是屬于同一類的),我們將它們所有的CDRPs的特征表示拿去做一個層次化聚集聚類,看看它們的CDRPs表征究竟有什么相似性。
上面的樹形圖,表示每個樣本之間的相似程度,越往底層,兩個樣本就越靠近,而越往高層,就越慢被聚到一起。縱坐標代表了兩個樣本的距離,里面的相似顏色代表的是他們被聚成一個子類了。我們在看這些圖片的聚類情況會發(fā)現(xiàn),如果圖片特征很相似,那么他們的CDRPs聚類結果也是很相似的。
另外還有一個很有趣的發(fā)現(xiàn),像左邊這50張圖,應該是某種魚,魚的圖片有這樣的一些分布規(guī)律:魚處在中間位置,采用的是橫拍模式,另外還有一類圖片,則是垂釣愛好者手里捧著魚蹲在地上拍照。我們發(fā)現(xiàn),這兩類圖片都被歸到魚這個類別,然而實際上圖片的特征存在很大的不同。
目前看來網(wǎng)絡 features 是檢驗不出來這種差異的,因為它們最終都被預測為魚這一個類別。然而我們的 CDRPs 表征就細致發(fā)現(xiàn)了其中的差異,就體現(xiàn)在兩者所走的關鍵路徑其實是不一樣的。
像左邊這張圖有紅框框起來的4張圖,其實是通過 CDRPs 的所分析出來的類似 outliner 的圖片。如果仔細看,會發(fā)現(xiàn)其中有一張圖片是一個人抱著魚,但是方向卻被旋轉了90度,按理來說這是一個類似于噪聲一樣的存在,然而我們的CDRPs卻能把它歸類到魚的類別,只是所走的關鍵路徑和其他樣本有著不一樣的特征,因此把它給聚類出來,變成一個發(fā)現(xiàn)。
像右面則是一個白頭鷹,中間第二個的聚類都是聚焦于鷹的頭部,而第三類則聚焦于鷹站在樹上,而左邊這個是單獨的 outliner,都是一些非常不清晰的圖像。
這里展示的是更多的一些結果。
對抗樣本檢測的應用
最后呢,我們嘗試用來做對抗樣本檢,像我之前所說的,正常樣本與對抗樣本,從輸入端來說沒有太大差別,但是從最后的預測結果來說,是有很大區(qū)別的。在我們看來,是兩張圖片在網(wǎng)絡里所走的關鍵路徑逐漸有了分歧,導致了最終的分開。
我們先看上方左邊這張圖,這張圖首先是一個正常樣本,加了噪聲以后,預測結果由貓變成了車輪。我們該如何體現(xiàn)這兩種關鍵路徑的區(qū)別呢?我們主要算的是這兩個樣本在不同層上所走的關鍵節(jié)點的相關性,我們先找到每一層各自的關鍵節(jié)點,然后有一個向量,然后根據(jù)這個向量去推算相關系數(shù)來表示兩個路徑的相似性。
上面這張圖里的橘紅線代表了這個相似程度,可以發(fā)現(xiàn)對抗樣本對于正常樣本的相似性是隨著層數(shù)增高的,而大致趨勢是逐漸下降的。簡單來說,高層的相似性要比底層小得多。
我們又算了下對抗樣本對于目標類別,它們這些樣本所走的這些關鍵路徑的相似性,接著計算車輪這一類別樣本的路徑相似性。我們找來車輪這一類別的50張圖片,將這50張樣本的每個路徑都算一遍相關系數(shù),上面這張圖叫做violinplot,展示的就是這50個系數(shù)的分布展示情況。
由于每個樣本之間存在差異,所有顯示結果有的高有的低。最后發(fā)現(xiàn),隨著層數(shù)加深,目標類別的相似系數(shù)會越來越高。比如在最高層的地方,violinplot的最低點都要比原始橘紅色的點要高。這也就是說,對抗樣本在高層所走的路徑和目標類別所走的路徑是很相似的,后面幾張圖也是在闡述這樣一個情況,具體的情況大家可以細致地去參考一下論文。
接著我們又去做對抗樣本檢測,檢測方法是通過取一些正常的樣本,比如說從 ImageNet 里挑出一千種類別,每一種類別取出1張圖片(有些實驗取出5張圖片,有些取出10張圖片等等),然后每一張圖片我們都產生一個對抗樣本(用的 FGSM 算法),然后作為訓練集,接著用我們的算法去算它的 CDRPs 表征,再取一個二分類的分類器來檢測和判斷這個路徑是屬于正常樣本還是對抗樣本。
訓練結束以后,我們就用這個分類器來做對抗樣本檢測,換句話說,我們自己構造了一個包含正常樣本與對抗樣本的數(shù)據(jù)集,然后用訓練所得到的分類器來預測哪一個是正常樣本,哪一個是對抗樣本。
下面的表格展示了我們不同實驗室的實驗結果,這個值如果越高,越近于1,就說明這個分類越完美。隨著訓練樣本的增加,分類結果變得越來越好之余,不同的二分類器所能達到的水準還是比較相似的(可能使用像 gradient boosting 或者 random forest 的方法會更好一些)。
結論
最后總結一下我今天所分享的內容,首先是我們提了一個全新的角度來進行網(wǎng)絡可解釋性,也就是通過尋找關鍵數(shù)據(jù)路徑,我們會發(fā)現(xiàn)有一些語義含義包含在數(shù)據(jù)路徑里頭。包括像層內節(jié)點,它會有一定的區(qū)分能力,而且隨著層數(shù)的增高,區(qū)分能力會逐漸加深。
同一時間,關鍵路徑又體現(xiàn)出類內樣本不同的輸入特征,有助于幫助我們發(fā)現(xiàn)一些數(shù)據(jù)集當中的樣本問題。
最后我們提了一個新的對抗樣本檢測算法,通過利用CDRPs的特征來檢測它究竟是真實樣本還是對抗樣本。CDRPs反映出對抗樣本在高層與正常樣本的距離較遠,在底層與正常樣本距離較近這樣一種特征模式,利用這種特征模式我們可以進行檢測,達到一個很好的防御效果。
以上就是本期嘉賓的全部分享內容。更多公開課視頻請到雷鋒網(wǎng)AI研習社社區(qū)(https://club.leiphone.com/)觀看。關注微信公眾號:AI 研習社(okweiwu),可獲取最新公開課直播時間預告。
雷峰網(wǎng)原創(chuàng)文章,未經(jīng)授權禁止轉載。詳情見轉載須知。