丁香五月天婷婷久久婷婷色综合91|国产传媒自偷自拍|久久影院亚洲精品|国产欧美VA天堂国产美女自慰视屏|免费黄色av网站|婷婷丁香五月激情四射|日韩AV一区二区中文字幕在线观看|亚洲欧美日本性爱|日日噜噜噜夜夜噜噜噜|中文Av日韩一区二区

您正在使用IE低版瀏覽器,為了您的雷峰網(wǎng)賬號安全和更好的產(chǎn)品體驗,強烈建議使用更快更安全的瀏覽器
此為臨時鏈接,僅用于文章預(yù)覽,將在時失效
人工智能開發(fā)者 正文
發(fā)私信給skura
發(fā)送

0

基于JAX的大規(guī)模并行MCMC:CPU25秒就可以處理10億樣本

本文作者: skura 2020-01-14 11:43
導(dǎo)語:在概率編程中,JAX 有很多優(yōu)勢

JAX 的表現(xiàn)出乎所有人的意料,在極端情況下,最大性能可提高 20 倍。由于 JAX 的 JIT 編譯開銷,Numpy 在少樣本、少量鏈的情況下會勝出。我報告了 tensorflow probability (TFP) 的結(jié)果,但請記住,這種比較是不公平的,因為它實現(xiàn)的隨機游走 metroplis 比我們的包含更多的功能。

重現(xiàn)結(jié)果所需的代碼可以在這里找到。使代碼運行得更快的技巧值得學(xué)習(xí)。

矢量化 MCMC

Colin Carroll 最近發(fā)布了一篇有趣的博文,使用 Numpy 和隨機游走 metropolis 算法 (RWMH) 的矢量化版本來生成大量的樣本,同時運行多個鏈以便對算法的收斂性進行后驗檢驗。這通常是通過在多線程機器上每個線程運行一個鏈來實現(xiàn)的,在 Python 中使用 joblib 或自定義后端。這么做很麻煩,但它能完成任務(wù)。

Colin 的 文章讓我感到非常興奮,因為我可以在幾乎不增加成本的情況下,同時對成千上萬的鏈進行取樣。他在文章中詳細(xì)介紹了幾個這一方法的應(yīng)用,但我有一種直覺,它可以完成更多的事情。

大約在同一時間,我偶然發(fā)現(xiàn)了 JAX。JAX 在概率編程語言環(huán)境中似乎很有趣,原因如下:

  • 在大多數(shù)情況下,它完全可以替代 Numpy;

  • Autodiff 很簡單;

  • 它的正向微分模式使得計算高階導(dǎo)數(shù)變得容易;

  • JAX 使用 XLA 執(zhí)行 JIT 編譯,即使在 CPU 上也可以加速代碼的運行;

  • 使用 GPU 和 TPU 非常簡單;

  • 這是一個偏好問題,但它更傾向于函數(shù)式編程。

在開始使用 JAX 實現(xiàn)一個框架之前,我想做一些基準(zhǔn)測試,以了解我要注冊的是什么。這里我將進行比較:

  • Numpy

  • Jax

  • Tensorflow Probability (TFP)

  • XLA 編譯的 Tensorflow Probability

關(guān)于基準(zhǔn)測試

在給出結(jié)果之前,首先需要聲明的是:

  1. 報告的時間是在我的筆記本電腦上運行 10 次的平均值,除了終端打開外,沒有任何其它操作。除了編譯后的 JAX 運行外,所有運行的時間都是使用 hyperfine 命令行工具測量的。

  2. 我的代碼可能不是最優(yōu)的,對于 TFP 來說尤其如此。

  3. 實驗是在 CPU 上進行的。JAX 和 TFP 可以運行在 GPU/TPU 上,所以可以期待額外的加速。

  4. 對于 Numpy 和 JAX 來說,采樣器是一個生成器,樣本不保存在內(nèi)存中但對 TFP 來說并非如此,因此在大型實驗期間,計算機會耗盡內(nèi)存。如果 TFP 沒有在堆棧上預(yù)先分配內(nèi)存,不斷地分配內(nèi)存也會影響性能。

  5. 在概率編程中重要的度量是每秒有效采樣的數(shù)量,而不是每秒采樣數(shù)量,前者后者更像是你使用的算法。這個基準(zhǔn)測試仍然可以很好地反映不同框架的原始性能。

設(shè)置和結(jié)果

我在對一個含有 4 個分量的任意高斯混合樣本進行采樣。使用 Numpy:

import numpy as np
from scipy.stats import norm
from scipy.special import logsumexp

def mixture_logpdf(x):
    loc = np.array([[-2, 0, 3.2, 2.5]]).T
    scale = np.array([[1.2, 1, 5, 2.8]]).T
    weights = np.array([[0.2, 0.3, 0.1, 0.4]]).T

    log_probs = norm(loc, scale).logpdf(x)

    return -logsumexp(np.log(weights) - log_probs, axis=0)

Numpy

Colin Carroll 的 MiniMC 是我見過的最簡單、最易讀的大都市隨機游走  Metropolis 和 Hamiltonian Monte Carlo 的實現(xiàn)。我的 Numpy 實現(xiàn)是他的一個迭代:

import numpy as np

def rw_metropolis_sampler(logpdf, initial_position):
    position = initial_position
    log_prob = logpdf(initial_position)
    yield position

    while True:
        move_proposals = np.random.normal(0, 0.1, size=initial_position.shape)
        proposal = position + move_proposals
        proposal_log_prob = logpdf(proposal)

        log_uniform = np.log(np.random.rand(initial_position.shape[0], initial_position.shape[1]))
        do_accept = log_uniform < proposal_log_prob - log_prob

        position = np.where(do_accept, proposal, position)
        log_prob = np.where(do_accept, proposal_log_prob, log_prob)
        yield position

JAX

JAX 的實現(xiàn)與 Numpy 非常相似:

from functools import partial

import jax
import jax.numpy as np

@partial(jax.jit, static_argnums=(0, 1))
def rw_metropolis_kernel(rng_key, logpdf, position, log_prob):
    move_proposals = jax.random.normal(rng_key, shape=position.shape) * 0.1
    proposal = position + move_proposals
    proposal_log_prob = logpdf(proposal)

    log_uniform = np.log(jax.random.uniform(rng_key, shape=position.shape))
    do_accept = log_uniform < proposal_log_prob - log_prob

    position = np.where(do_accept, proposal, position)
    log_prob = np.where(do_accept, proposal_log_prob, log_prob)
    return position, log_prob


def rw_metropolis_sampler(rng_key, logpdf, initial_position):
    position = initial_position
    log_prob = logpdf(initial_position)
    yield position

    while True:
        position, log_prob = rw_metropolis_kernel(rng_key, logpdf, position, log_prob)
        yield position

如果你熟悉 Numpy,那么你應(yīng)該非常熟悉它的語法。JAX 和它有一些不同之處:

  •  jax.numpy 充當(dāng) numpy 的替代。對于只涉及數(shù)組操作的函數(shù),用 import jax.numpy as np 替換 import numpy as np,這會給你帶來性能上的提升。

  • JAX 處理隨機數(shù)生成的方式與其他 Python 包不同,這是有原因的 (請閱讀這篇文章:https://github.com/google/jax/blob/master/design_notes/prng.md ) 。每個發(fā)行版都以一個 PRNG 鍵作為輸入。

  • 因為 JAX 不能編譯生成器,我從采樣器中提取內(nèi)核。因此,我們提取并 JIT 完成所有繁重工作的函數(shù):rw_metropolis_kernel。

  • 我們需要對 JAX 的編譯器提供一點幫助,即指出當(dāng)函數(shù)多次運行時哪些參數(shù)不會改變:@partial(jax.jit, argnums=(0, 1))。如果將函數(shù)作為參數(shù)傳遞,這是必需的,并且可以啟用進一步的編譯時優(yōu)化。

Tensorflow Probability

對于 TFP,我們使用庫中實現(xiàn)的隨機游走 Metropolis 算法:

from functools import partial

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

def run_raw_metropolis(n_dims, n_samples, n_chains, target):
    samples, _ = tfp.mcmc.sample_chain(
        num_results=n_samples,
        current_state=np.zeros((n_dims, n_chains), dtype=np.float32),
        kernel=tfp.mcmc.RandomWalkMetropolis(target.log_prob, seed=42),
        num_burnin_steps=0,
        parallel_iterations=8,
    )
    return samples

run_mcm = partial(run_tfp_mcmc, n_dims, n_samples, n_chains, target)

## Without XLA
run_mcm()

## With XLA compilation
tf.xla.experimental.compile(run_mcm)

結(jié)果

我們有兩個自由維度:樣本的數(shù)量和鏈的數(shù)量,第一個依賴于原始的數(shù)字處理能力,第二個也依賴于向量化的實現(xiàn)方式。因此,我決定在兩個維度上對算法進行基準(zhǔn)測試。

我考慮以下情況:

  1. Numpy 實現(xiàn);

  2. JAX 實現(xiàn);

  3. 減去編譯時間的 JAX 實現(xiàn)。這只是一個假設(shè)的情況,目的是顯示編譯帶來的改進。

  4. Tensorflow Probability;

  5. 實驗 XLA 編譯的 Tensorflow Probability。

用 1000 條鏈繪制越來越多的樣本

我們固定鏈的數(shù)量,并改變樣本的數(shù)量。

基于JAX的大規(guī)模并行MCMC:CPU25秒就可以處理10億樣本

你將注意到 TFP 實現(xiàn)的缺失點。由于 TFP 算法存儲所有的樣本,所以它會耗盡內(nèi)存。這在 XLA 編譯的版本中沒有發(fā)生,可能是因為它使用了內(nèi)存效率更高的數(shù)據(jù)結(jié)構(gòu)。

對于少于 1000 個樣本,普通的 TFP 和 Numpy 實現(xiàn)比它們的編譯副本要快。這是由于編譯開銷造成的:當(dāng)你減去 JAX 的編譯時間 (從而獲得綠色曲線) 時,它會大大加快速度。只有當(dāng)樣本的數(shù)量變得很大,并且總抽樣時間取決于抽取樣本的時間時,你才開始從編譯中獲益。

沒有什么神奇的:JIT 編譯意味著一個明顯的、但不變的計算開銷。

我建議在大多數(shù)情況下使用 JAX。只有當(dāng)相同的代碼執(zhí)行超過 10 次時,在 0.3 秒而不是 3 秒內(nèi)進行采樣的差異才會產(chǎn)生影響。然而,編譯是只會發(fā)生一次。在這種情況下,計算開銷將在你達(dá)到 10 次迭代之前得到回報。實際上,JAX 贏了。

用越來越多的鏈繪制 1000 個樣本

在這里,我們固定樣本的數(shù)量,改變鏈的數(shù)量。

基于JAX的大規(guī)模并行MCMC:CPU25秒就可以處理10億樣本

JAX 仍然明顯地贏了:只要鏈的數(shù)量達(dá)到 10,000,它就比 Numpy 更快。你將注意到 JAX 曲線上有一個凸起,這完全是由于編譯造成的 (綠色曲線沒有這個凸起)。我不知道為什么,如果有答案請告訴我!

這就是令人興奮的亮點:

JAX 可以在 25 秒內(nèi)在 CPU 上生成 10 億個樣本,比 Numpy 快 20 倍!

結(jié)論

對于允許我們用純 python 編寫代碼的項目,JAX 的性能是令人難以置信的。Numpy 仍然是一個不錯的選擇,特別是對于那些 JAX 的大部分執(zhí)行時間都花在編譯上的項目來說尤其如此。

但是,Numpy 不適合概率編程語言。如 Hamiltonian Monte Carlo 這樣的高效抽樣算 Uber 優(yōu)步的團隊開始和 JAX 在 Numpyro 上合作。

不要過多地解讀 Tensorflow Probability 的拙劣表現(xiàn)。當(dāng)從分布中采樣時,重要的不是原始速度,而是每秒有效采樣的數(shù)量。TFP 的實現(xiàn)包括更多的附加功能,我希望它在每秒有效采樣樣本數(shù)方面更具競爭力。

最后,請注意,用鏈的數(shù)量乘以樣本的數(shù)量要比用樣本的數(shù)量乘以樣本的數(shù)量容易得多。我們還不知道如何處理這些鏈,但我有一種直覺,一旦我們這樣做了,概率編程將會有另一個突破。

via:https://rlouf.github.io/post/jax-random-walk-metropolis/

雷鋒網(wǎng)雷鋒網(wǎng)雷鋒網(wǎng)

雷峰網(wǎng)版權(quán)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。

基于JAX的大規(guī)模并行MCMC:CPU25秒就可以處理10億樣本

分享:
相關(guān)文章
當(dāng)月熱門文章
最新文章
請?zhí)顚懮暾埲速Y料
姓名
電話
郵箱
微信號
作品鏈接
個人簡介
為了您的賬戶安全,請驗證郵箱
您的郵箱還未驗證,完成可獲20積分喲!
請驗證您的郵箱
立即驗證
完善賬號信息
您的賬號已經(jīng)綁定,現(xiàn)在您可以設(shè)置密碼以方便用郵箱登錄
立即設(shè)置 以后再說