0
雷鋒網(wǎng) AI 研習(xí)社消息,相信大家對于「深度學(xué)習(xí)教父」Geoffery Hinton 在去年年底發(fā)表的膠囊網(wǎng)絡(luò)還記憶猶新,在論文 Dynamic Routing between Capsules 中,Hinton 團(tuán)隊提出了一種全新的網(wǎng)絡(luò)結(jié)構(gòu)。為了避免網(wǎng)絡(luò)結(jié)構(gòu)的雜亂無章,他們提出把關(guān)注同一個類別或者同一個屬性的神經(jīng)元打包集合在一起,好像膠囊一樣。在神經(jīng)網(wǎng)絡(luò)工作時,這些膠囊間的通路形成稀疏激活的樹狀結(jié)構(gòu)(整個樹中只有部分路徑上的膠囊被激活)。這樣一來,Capsule 也就具有更好的解釋性。
在實驗結(jié)果上,CapsNet 在數(shù)字識別和健壯性上都取得了不錯的效果。詳情可以參見終于盼來了Hinton的Capsule新論文,它能開啟深度神經(jīng)網(wǎng)絡(luò)的新時代嗎?
日前,該論文的第一作者 Sara Sabour 在 GitHub 上公布了論文代碼,大家可以馬上動手實踐起來。雷鋒網(wǎng) AI 研習(xí)社將教程編譯整理如下:
所需配置:
TensorFlow(點(diǎn)擊 http://www.tensorflow.org 進(jìn)行安裝或升級)
NumPy (詳情點(diǎn)擊 http://www.numpy.org/ )
GPU
執(zhí)行 test 程序,來驗證安裝是否正確,諸如:
python layers_test.py
快速 MNIST 測試:
下載并提取 MNIST tfrecord 到 $DATA_DIR/ 下:
https://storage.googleapis.com/capsule_toronto/mnist_data.tar.gz
下載并提取 MNIST 模型 checkpoint 到 $CKPT_DIR 下:
https://storage.googleapis.com/capsule_toronto/mnist_checkpoints.tar.gz
python experiment.py --data_dir=$DATA_DIR/mnist_data/ --train=false \
--summary_dir=/tmp/ --checkpoint=$CKPT_DIR/mnist_checkpoint/model.ckpt-1
快速 CIFAR10 ensemble 測試:
下載并提取 cifar10 二進(jìn)制文件到 $DATA_DIR/ 下:
下載并提取 cifar10 模型 checkpoint 到 $CKPT_DIR 下:
https://storage.googleapis.com/capsule_toronto/cifar_checkpoints.tar.gz
將目錄($DATA_DIR)作為 data_dir 來傳遞:
python experiment.py --data_dir=$DATA_DIR --train=false --dataset=cifar10 \
--hparams_override=num_prime_capsules=64,padding=SAME,leaky=true,remake=false \
--summary_dir=/tmp/ --checkpoint=$CKPT_DIR/cifar/cifar{}/model.ckpt-600000 \
--num_trials=7
CIFAR10 訓(xùn)練指令:
python experiment.py --data_dir=$DATA_DIR --dataset=cifar10 --max_steps=600000\
--hparams_override=num_prime_capsules=64,padding=SAME,leaky=true,remake=false \
--summary_dir=/tmp/
MNIST full 訓(xùn)練指令:
也可以執(zhí)行--validate=true as well 在訓(xùn)練-測試集上訓(xùn)練
執(zhí)行 --num_gpus=NUM_GPUS 在多塊GPU上訓(xùn)練
python experiment.py --data_dir=$DATA_DIR/mnist_data/ --max_steps=300000\
--summary_dir=/tmp/attempt0/
MNIST baseline 訓(xùn)練指令:
python experiment.py --data_dir=$DATA_DIR/mnist_data/ --max_steps=300000\
--summary_dir=/tmp/attempt1/ --model=baseline
To test on validation during training of the above model:
訓(xùn)練如上模型時,在驗證集上進(jìn)行測試(記住,在訓(xùn)練過程中會持續(xù)執(zhí)行指令):
在訓(xùn)練時執(zhí)行 --validate=true 也一樣
可能需要兩塊 GPU,一塊用于訓(xùn)練集,一塊用于驗證集
如果所有的測試都在一臺機(jī)器上,你需要對訓(xùn)練集、驗證集的測試中限制 RAM 消耗。如果不這樣,TensorFlow 會在一開始占用所有的 RAM,這樣就不能執(zhí)行其他工作了
python experiment.py --data_dir=$DATA_DIR/mnist_data/ --max_steps=300000\
--summary_dir=/tmp/attempt0/ --train=false --validate=true
大家可以通過 --num_targets=2 和 --data_dir=$DATA_DIR/multitest_6shifted_mnist.tfrecords@10 在 MultiMNIST 上進(jìn)行測試或訓(xùn)練,生成 multiMNIST/MNIST 記錄的代碼在 input_data/mnist/mnist_shift.py 目錄下。
multiMNIST 測試代碼:
python mnist_shift.py --data_dir=$DATA_DIR/mnist_data/ --split=test --shift=6
--pad=4 --num_pairs=1000 --max_shard=100000 --multi_targets=true
可以通過 --shift=6 --pad=6 來構(gòu)造 affNIST expanded_mnist
論文地址:https://arxiv.org/pdf/1710.09829.pdf
GitHub 地址:https://github.com/Sarasra/models/tree/master/research/capsules
雷鋒網(wǎng) AI 研習(xí)社編譯整理。
(完)
雷峰網(wǎng)版權(quán)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。