0
JAX 的表現(xiàn)出乎所有人的意料,在極端情況下,最大性能可提高 20 倍。由于 JAX 的 JIT 編譯開銷,Numpy 在少樣本、少量鏈的情況下會勝出。我報告了 tensorflow probability (TFP) 的結(jié)果,但請記住,這種比較是不公平的,因為它實現(xiàn)的隨機游走 metroplis 比我們的包含更多的功能。
重現(xiàn)結(jié)果所需的代碼可以在這里找到。使代碼運行得更快的技巧值得學(xué)習(xí)。
矢量化 MCMC
Colin Carroll 最近發(fā)布了一篇有趣的博文,使用 Numpy 和隨機游走 metropolis 算法 (RWMH) 的矢量化版本來生成大量的樣本,同時運行多個鏈以便對算法的收斂性進行后驗檢驗。這通常是通過在多線程機器上每個線程運行一個鏈來實現(xiàn)的,在 Python 中使用 joblib 或自定義后端。這么做很麻煩,但它能完成任務(wù)。
Colin 的 文章讓我感到非常興奮,因為我可以在幾乎不增加成本的情況下,同時對成千上萬的鏈進行取樣。他在文章中詳細(xì)介紹了幾個這一方法的應(yīng)用,但我有一種直覺,它可以完成更多的事情。
大約在同一時間,我偶然發(fā)現(xiàn)了 JAX。JAX 在概率編程語言環(huán)境中似乎很有趣,原因如下:
在大多數(shù)情況下,它完全可以替代 Numpy;
Autodiff 很簡單;
它的正向微分模式使得計算高階導(dǎo)數(shù)變得容易;
JAX 使用 XLA 執(zhí)行 JIT 編譯,即使在 CPU 上也可以加速代碼的運行;
使用 GPU 和 TPU 非常簡單;
這是一個偏好問題,但它更傾向于函數(shù)式編程。
在開始使用 JAX 實現(xiàn)一個框架之前,我想做一些基準(zhǔn)測試,以了解我要注冊的是什么。這里我將進行比較:
Numpy
Jax
Tensorflow Probability (TFP)
XLA 編譯的 Tensorflow Probability
關(guān)于基準(zhǔn)測試
在給出結(jié)果之前,首先需要聲明的是:
報告的時間是在我的筆記本電腦上運行 10 次的平均值,除了終端打開外,沒有任何其它操作。除了編譯后的 JAX 運行外,所有運行的時間都是使用 hyperfine 命令行工具測量的。
我的代碼可能不是最優(yōu)的,對于 TFP 來說尤其如此。
實驗是在 CPU 上進行的。JAX 和 TFP 可以運行在 GPU/TPU 上,所以可以期待額外的加速。
對于 Numpy 和 JAX 來說,采樣器是一個生成器,樣本不保存在內(nèi)存中但對 TFP 來說并非如此,因此在大型實驗期間,計算機會耗盡內(nèi)存。如果 TFP 沒有在堆棧上預(yù)先分配內(nèi)存,不斷地分配內(nèi)存也會影響性能。
在概率編程中重要的度量是每秒有效采樣的數(shù)量,而不是每秒采樣數(shù)量,前者后者更像是你使用的算法。這個基準(zhǔn)測試仍然可以很好地反映不同框架的原始性能。
設(shè)置和結(jié)果
我在對一個含有 4 個分量的任意高斯混合樣本進行采樣。使用 Numpy:
import numpy as np
from scipy.stats import norm
from scipy.special import logsumexp
def mixture_logpdf(x):
loc = np.array([[-2, 0, 3.2, 2.5]]).T
scale = np.array([[1.2, 1, 5, 2.8]]).T
weights = np.array([[0.2, 0.3, 0.1, 0.4]]).T
log_probs = norm(loc, scale).logpdf(x)
return -logsumexp(np.log(weights) - log_probs, axis=0)
Numpy
Colin Carroll 的 MiniMC 是我見過的最簡單、最易讀的大都市隨機游走 Metropolis 和 Hamiltonian Monte Carlo 的實現(xiàn)。我的 Numpy 實現(xiàn)是他的一個迭代:
import numpy as np
def rw_metropolis_sampler(logpdf, initial_position):
position = initial_position
log_prob = logpdf(initial_position)
yield position
while True:
move_proposals = np.random.normal(0, 0.1, size=initial_position.shape)
proposal = position + move_proposals
proposal_log_prob = logpdf(proposal)
log_uniform = np.log(np.random.rand(initial_position.shape[0], initial_position.shape[1]))
do_accept = log_uniform < proposal_log_prob - log_prob
position = np.where(do_accept, proposal, position)
log_prob = np.where(do_accept, proposal_log_prob, log_prob)
yield position
JAX
JAX 的實現(xiàn)與 Numpy 非常相似:
from functools import partial
import jax
import jax.numpy as np
@partial(jax.jit, static_argnums=(0, 1))
def rw_metropolis_kernel(rng_key, logpdf, position, log_prob):
move_proposals = jax.random.normal(rng_key, shape=position.shape) * 0.1
proposal = position + move_proposals
proposal_log_prob = logpdf(proposal)
log_uniform = np.log(jax.random.uniform(rng_key, shape=position.shape))
do_accept = log_uniform < proposal_log_prob - log_prob
position = np.where(do_accept, proposal, position)
log_prob = np.where(do_accept, proposal_log_prob, log_prob)
return position, log_prob
def rw_metropolis_sampler(rng_key, logpdf, initial_position):
position = initial_position
log_prob = logpdf(initial_position)
yield position
while True:
position, log_prob = rw_metropolis_kernel(rng_key, logpdf, position, log_prob)
yield position
如果你熟悉 Numpy,那么你應(yīng)該非常熟悉它的語法。JAX 和它有一些不同之處:
jax.numpy 充當(dāng) numpy 的替代。對于只涉及數(shù)組操作的函數(shù),用 import jax.numpy as np 替換 import numpy as np,這會給你帶來性能上的提升。
JAX 處理隨機數(shù)生成的方式與其他 Python 包不同,這是有原因的 (請閱讀這篇文章:https://github.com/google/jax/blob/master/design_notes/prng.md ) 。每個發(fā)行版都以一個 PRNG 鍵作為輸入。
因為 JAX 不能編譯生成器,我從采樣器中提取內(nèi)核。因此,我們提取并 JIT 完成所有繁重工作的函數(shù):rw_metropolis_kernel。
我們需要對 JAX 的編譯器提供一點幫助,即指出當(dāng)函數(shù)多次運行時哪些參數(shù)不會改變:@partial(jax.jit, argnums=(0, 1))。如果將函數(shù)作為參數(shù)傳遞,這是必需的,并且可以啟用進一步的編譯時優(yōu)化。
Tensorflow Probability
對于 TFP,我們使用庫中實現(xiàn)的隨機游走 Metropolis 算法:
from functools import partial
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
def run_raw_metropolis(n_dims, n_samples, n_chains, target):
samples, _ = tfp.mcmc.sample_chain(
num_results=n_samples,
current_state=np.zeros((n_dims, n_chains), dtype=np.float32),
kernel=tfp.mcmc.RandomWalkMetropolis(target.log_prob, seed=42),
num_burnin_steps=0,
parallel_iterations=8,
)
return samples
run_mcm = partial(run_tfp_mcmc, n_dims, n_samples, n_chains, target)
## Without XLA
run_mcm()
## With XLA compilation
tf.xla.experimental.compile(run_mcm)
結(jié)果
我們有兩個自由維度:樣本的數(shù)量和鏈的數(shù)量,第一個依賴于原始的數(shù)字處理能力,第二個也依賴于向量化的實現(xiàn)方式。因此,我決定在兩個維度上對算法進行基準(zhǔn)測試。
我考慮以下情況:
Numpy 實現(xiàn);
JAX 實現(xiàn);
減去編譯時間的 JAX 實現(xiàn)。這只是一個假設(shè)的情況,目的是顯示編譯帶來的改進。
Tensorflow Probability;
實驗 XLA 編譯的 Tensorflow Probability。
用 1000 條鏈繪制越來越多的樣本
我們固定鏈的數(shù)量,并改變樣本的數(shù)量。

你將注意到 TFP 實現(xiàn)的缺失點。由于 TFP 算法存儲所有的樣本,所以它會耗盡內(nèi)存。這在 XLA 編譯的版本中沒有發(fā)生,可能是因為它使用了內(nèi)存效率更高的數(shù)據(jù)結(jié)構(gòu)。
對于少于 1000 個樣本,普通的 TFP 和 Numpy 實現(xiàn)比它們的編譯副本要快。這是由于編譯開銷造成的:當(dāng)你減去 JAX 的編譯時間 (從而獲得綠色曲線) 時,它會大大加快速度。只有當(dāng)樣本的數(shù)量變得很大,并且總抽樣時間取決于抽取樣本的時間時,你才開始從編譯中獲益。
沒有什么神奇的:JIT 編譯意味著一個明顯的、但不變的計算開銷。
我建議在大多數(shù)情況下使用 JAX。只有當(dāng)相同的代碼執(zhí)行超過 10 次時,在 0.3 秒而不是 3 秒內(nèi)進行采樣的差異才會產(chǎn)生影響。然而,編譯是只會發(fā)生一次。在這種情況下,計算開銷將在你達(dá)到 10 次迭代之前得到回報。實際上,JAX 贏了。
用越來越多的鏈繪制 1000 個樣本
在這里,我們固定樣本的數(shù)量,改變鏈的數(shù)量。

JAX 仍然明顯地贏了:只要鏈的數(shù)量達(dá)到 10,000,它就比 Numpy 更快。你將注意到 JAX 曲線上有一個凸起,這完全是由于編譯造成的 (綠色曲線沒有這個凸起)。我不知道為什么,如果有答案請告訴我!
這就是令人興奮的亮點:
JAX 可以在 25 秒內(nèi)在 CPU 上生成 10 億個樣本,比 Numpy 快 20 倍!
結(jié)論
對于允許我們用純 python 編寫代碼的項目,JAX 的性能是令人難以置信的。Numpy 仍然是一個不錯的選擇,特別是對于那些 JAX 的大部分執(zhí)行時間都花在編譯上的項目來說尤其如此。
但是,Numpy 不適合概率編程語言。如 Hamiltonian Monte Carlo 這樣的高效抽樣算 Uber 優(yōu)步的團隊開始和 JAX 在 Numpyro 上合作。
不要過多地解讀 Tensorflow Probability 的拙劣表現(xiàn)。當(dāng)從分布中采樣時,重要的不是原始速度,而是每秒有效采樣的數(shù)量。TFP 的實現(xiàn)包括更多的附加功能,我希望它在每秒有效采樣樣本數(shù)方面更具競爭力。
最后,請注意,用鏈的數(shù)量乘以樣本的數(shù)量要比用樣本的數(shù)量乘以樣本的數(shù)量容易得多。我們還不知道如何處理這些鏈,但我有一種直覺,一旦我們這樣做了,概率編程將會有另一個突破。
via:https://rlouf.github.io/post/jax-random-walk-metropolis/
雷鋒網(wǎng)雷鋒網(wǎng)雷鋒網(wǎng)
雷峰網(wǎng)版權(quán)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。