0
本文作者: skura | 2019-02-20 19:12 |
雷鋒網(wǎng) AI 科技評(píng)論按,用對(duì)抗性邊緣學(xué)習(xí)修復(fù)生成圖像是一種新的圖像修復(fù)方法,它可以更好地復(fù)制填充區(qū)域,它的細(xì)節(jié)部分展現(xiàn)了開發(fā)者對(duì)藝術(shù)工作者工作方式的理解:線條優(yōu)先,顏色次之。對(duì)應(yīng)的論文在 arxiv 上可以查看:https://arxiv.org/abs/1901.00212。
文中提出了一種 2 階對(duì)抗式邊緣連接模型,該模型由一個(gè)邊緣生成器和一個(gè)圖像完成網(wǎng)絡(luò)組成。邊緣生成器先描繪出圖像缺失區(qū)域(規(guī)則和不規(guī)則)的邊緣,圖像完成網(wǎng)絡(luò)先驗(yàn)使用描繪出的邊緣填充缺失區(qū)域。論文對(duì)該系統(tǒng)進(jìn)行了詳細(xì)的描述。
(a)輸入有缺失區(qū)域的圖像,缺失區(qū)域用白色表示。(b)計(jì)算得到的邊緣,使用 Canny 邊緣檢測(cè)器計(jì)算(針對(duì)可用區(qū)域)黑色繪制的邊緣;而藍(lán)色顯示的邊緣由邊緣生成器網(wǎng)絡(luò)描繪。(c)擬用方法的圖像修復(fù)結(jié)果。
準(zhǔn)備:
Python 3
PyTorch 1.0
NVIDIA GPU + CUDA cuDNN
安裝:
復(fù)制下面這個(gè) repo:
git clone https://github.com/knazeri/edge-connect.gitcd edge-connect
cd edge-connect
從 http://pytorch.org 安裝 PyTorch 及其相關(guān)依賴。
安裝 python 文件:
pip install -r requirements.txt
數(shù)據(jù)集
1. 圖像
這里使用 Places2, CelebA 以及 Paris Street-View 數(shù)據(jù)集。從官網(wǎng)下載數(shù)據(jù)集,在整個(gè)數(shù)據(jù)集上訓(xùn)練模型。
下載完成后,運(yùn)行 scripts/flist.py 這個(gè)文件來生成訓(xùn)練、測(cè)試和驗(yàn)證集文件列表。例如,要在 Places2 數(shù)據(jù)集上生成訓(xùn)練集文件列表,請(qǐng)運(yùn)行:
mkdir datasets
python ./scripts/flist.py --path path_to_places2_train_set --output ./datasets/places_train.flist
2 .不規(guī)則掩膜
我們的模型是在 Liu 等人提供的不規(guī)則掩模數(shù)據(jù)集上進(jìn)行訓(xùn)練的,你可以從他們的網(wǎng)站上下載公開的不規(guī)則掩膜數(shù)據(jù)集。雷鋒網(wǎng)
請(qǐng)使用 scripts/flist.py 生成上述訓(xùn)練、測(cè)試和驗(yàn)證集掩膜文件列表。
開始
使用以下鏈接下載預(yù)先訓(xùn)練的模型,并將其復(fù)制到./checkpoints 目錄下。
或者,你可以運(yùn)行以下腳本自動(dòng)下載預(yù)訓(xùn)練的模型:
bash ./scripts/download_model.sh
1 .訓(xùn)練
要訓(xùn)練模型,請(qǐng)創(chuàng)建一個(gè)類似于示例配置文件的 config.yaml 文件,并將其復(fù)制到檢查點(diǎn)目錄下。有關(guān)模型配置的詳細(xì)信息,請(qǐng)參閱配置指南。
EdgeConnect 的訓(xùn)練分為三個(gè)階段:1)邊緣模型的訓(xùn)練;2)內(nèi)部模型的訓(xùn)練;3)聯(lián)合模型的訓(xùn)練。訓(xùn)練模型:
python train.py --model [stage] --checkpoints [path to checkpoints]
例如,要在./checkpoints/places2 目錄下的 places2 數(shù)據(jù)集上訓(xùn)練邊緣模型:
python train.py --model 1 --checkpoints ./checkpoints/places2
模型的收斂性因數(shù)據(jù)集而異。例如,Places2 數(shù)據(jù)集在兩個(gè)時(shí)期中的一個(gè)就能聚合,而較小的數(shù)據(jù)集(如 CelebA)則需要將近 40 個(gè)時(shí)期才能聚合。你可以通過更改配置文件中的 MAX_ITERS 值來設(shè)置訓(xùn)練迭代次數(shù)。雷鋒網(wǎng)
2 .測(cè)試
要測(cè)試模型,請(qǐng)創(chuàng)建一個(gè)與示例配置文件類似的 config.yaml 文件,并將其復(fù)制到檢查點(diǎn)目錄下。
你可以在所有三個(gè)階段上測(cè)試模型:邊緣模型、內(nèi)部模型和聯(lián)合模型。在每種情況下,都需要提供一個(gè)輸入圖像(帶掩膜的圖像)和一個(gè)灰度掩膜文件。請(qǐng)確保掩膜文件覆蓋輸入圖像中的整個(gè)掩膜區(qū)域。測(cè)試模型:
python test.py \
--model [stage]
--checkpoints [path to checkpoints] \
--input [path to input directory or file] \
--mask [path to masks directory or mask file] \
--output [path to the output directory]
我們?cè)?/examples 目錄下提供了一些測(cè)試示例,請(qǐng)下載預(yù)訓(xùn)練模型并運(yùn)行:
python test.py \
--checkpoints ./checkpoints/places2
--input ./examples/places2/images
--mask ./examples/places2/masks
--output ./checkpoints/results
此腳本將在./examples/places2/images 中使用和./examples/places2/mask 對(duì)應(yīng)的掩膜圖像,并將結(jié)果保存在./checkpoints/results 目錄中。默認(rèn)情況下,test.py 腳本在階段 3 上運(yùn)行(--model=3)。
3 .評(píng)估
要評(píng)估模型,你需要首先在測(cè)試模式下對(duì) validation 集運(yùn)行模型,并將結(jié)果保存到磁盤上。我們提供了一個(gè)實(shí)用程序./scripts/metrics.py,使用 PSNR, SSIM 和 Mean Absolute Error 評(píng)估模型:
python ./scripts/metrics.py --data-path [驗(yàn)證集路徑] --輸出路徑 [模型輸出路徑]
要測(cè)量 FID 分?jǐn)?shù),請(qǐng)運(yùn)行./scripts/fid_score.py。我們利用了這里的 FID 的 pytorch 實(shí)現(xiàn),它使用了 pytorch 初始模型中的預(yù)訓(xùn)練權(quán)重。
python ./scripts/fid_score.py --path [驗(yàn)證集路徑, 模型輸出路徑] --gpu [要使用的 gpu id]
via:https://github.com/knazeri/edge-connect
雷峰網(wǎng)版權(quán)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。