前端如何開始深度學習,那不妨試試JAX

語言: CN / TW / HK

一、簡介

在深度學習方面,TensorFlow 和 PyTorch是絕對的王者。但是,但除了這兩個框架之外,一些新生的框架也不容小覷,比如谷歌推出的 JAX深度學習框架。

image.png

1.1、快速發展的JAX

JAX是一個用於高性能數值計算的Python庫,專門為深度學習領域的高性能計算而設計。自2018 年底谷歌的 JAX出現以來,它的受歡迎程度一直在穩步增長,並且越來越多的來自Google 大腦與其他項目也在使用 JAX。隨着JAX越來越火, JAX 似乎正在成為下一代的大型深度學習框架。目前,JAX 在 GitHub 上已累積獲得了超過 19.4K 的關注。

image.png

JAX 是Autograd和XLA的結合,JAX 本身不是一個深度學習的框架,他是一個高性能的數值計算庫,更是結合了可組合的函數轉換庫,用於高性能機器學習研究。深度學習只是其中的一部分而已,但是你完全可以把自己的深度學習移植到JAX 上面。

藉助Autograd的更新版本,JAX 可以自動區分原生 Python 和 NumPy 函數。它可以通過循環、分支、遞歸和閉包進行微分,並且可以對導數的導數進行導數。它支持反向模式微分(也稱為反向傳播)grad和正向模式微分,兩者可以任意組合成任何順序。

説到這,就不得不提 NumPy。NumPy 是 Python 中的一個基礎數值運算庫,被廣泛使用的支持大量的維度數組與矩陣運算的數學函數庫。不過, numpy 本身不支持 GPU 或其他硬件加速器,也沒有對反向傳播的內置支持,此外,Python 本身的速度限制阻礙了 NumPy 使用,所以少有研究者在生產環境下直接用 numpy 訓練或部署深度學習模型。

在此情況下,出現了眾多的深度學習框架,如 PyTorch、TensorFlow 等。但是 numpy 具有靈活、調試方便、API 穩定等獨特的優勢,而 JAX 的主要出發點就是將 numpy 的以上優勢與硬件加速結合,進而支持機器學習研究。除此之外,JAX還具有如下一些優點:

  • 可差分:基於梯度的優化方法在機器學習領域具有十分重要的作用。JAX 可通過grad、hessian、jacfwd 和 jacrev 等函數轉換,原生支持任意數值函數的前向和反向模式的自動微分。
  • 向量化:在機器學習中,通常需要在大規模的數據上運行相同的函數,例如計算整個批次的損失或每個樣本的損失等。JAX 通過 vmap 變換提供了自動矢量化算法,大大簡化了這種類型的計算,這使得研究人員在處理新算法時無需再去處理批量化的問題。JAX 同時還可以通過 pmap 轉換支持大規模的數據並行,從而優雅地將單個處理器無法處理的大數據進行處理。
  • JIT編譯:XLA (Accelerated Linear Algebra, 加速線性代數) 被用於 JIT 即時編譯,在 GPU 和雲 TPU 加速器上執行 JAX 程序。JIT 編譯與 JAX 的 API (與 Numpy 一致的數據函數) 為研發人員提供了便捷接入高性能計算的可能,無需特別的經驗就能將計算運行在多個加速器上。

目前,基於 JAX 已有很多優秀的開源項目,如谷歌的神經網絡庫團隊開發了 Haiku,這是一個面向 Jax 的深度學習代碼庫,通過 Haiku,用户可以在 Jax 上進行面向對象開發;又比如 RLax,這是一個基於 Jax 的強化學習庫,用户使用 RLax 就能進行 Q-learning 模型的搭建和訓練;此外還包括基於 JAX 的深度學習庫 JAXnet,該庫一行代碼就能定義計算圖、可進行 GPU 加速。可以説, JAX其實就是 TensorFlow 的一個簡化庫,支持大部分的TensorFlow 功能,而且比 TensorFlow 更加簡潔易用。

1.2、 JAX 、TensorFlow、PyTorch對比

在深度學習領域,一直都是國外的巨頭公司霸佔着,比如谷歌的TensorFlow、Facebook 的 PyTorch、微軟的 CNTK、亞馬遜 AWS 的 MXnet 等。那他們有什麼特點呢?下面,我們選取JAX 、TensorFlow、PyTorch進行一下對比。

image.png

1.2.1 TensorFlow

TensorFlow 是由谷歌推出的基於數據流編程符號數學系統,被廣泛應用在各類機器學習算法的實現中。具有以下特點: - Tensoflow是一個對用户非常友好的框架。高級 API -Keras 的可用性使模型層定義、損失函數和模型創建變得非常容易。TensorFlow2.0 帶有動態圖類型,使得該庫對用户更加友好,並且是對以前版本的重大升級。 - 由於Keras 的這種高級接口本身的缺陷,所以研究人員在使用自建的模型時自由度降低了。 - TensorFlow 提供的可視化工具包TensorBoard允許用户可視化損失函數、模型圖、分析等,提升了交互體驗。

因此,如果需要使用深度學習或者部署自己的模型,TensorFlow 可能是一個不錯的深度學習框架框架。並且,TensorFlow提供的TensorFlow Lite 版本能將 ML 模型部署到移動和邊緣設備,使得移動設備也能進行深度學習。

1.2.2 PyTorch

PyTorch是由Facebook開源的神經網絡框架,專門針對 GPU 加速的深度神經網絡(DNN)編程。如果説,一兩年前大家談起深度學習還只談起TensorFlow,那麼現在PyTorch也在成為越來越多的開發者的選擇。PyTorch具有如下一些特性: - 與 TensorFlow 不同,PyTorch 使用動態類型圖,這意味着執行圖是隨時隨地創建的,它允許開發者隨時修改和檢查圖的內容。 - 除了用户友好的API 之外,PyTorch 還允許對用户的機器學習模型進行越來越多的自定義控制。這樣一來,我們可以在訓練期間模型的前向和後向傳遞期間檢查和修改輸出。 - PyTorch 允許擴展他們的代碼,輕鬆添加新的損失函數和用户定義的層。PyTorch autograd 足夠強大,可以通過這些用户定義的層進行區分,用户還可以選擇定義梯度的計算方式。 - PyTorch 對數據並行性和 GPU 使用有廣泛的支持。 - PyTorch 比 TensorFlow 更 Pythonic。PyTorch 非常適合 python 生態系統,它允許使用 Python 調試器工具來調試 PyTorch 代碼。

1.2.3 JAX

JAX是一個來自 Google 的機器學習庫,它更像是一個 autograd 庫,可以區分每個本機 python 和 NumPy 代碼。正如我們所看到的,深度學習只是 JAX 功能的一小部分:

image.png

正如官方描述的那樣,JAX 能夠對 Python+NumPy 程序進行可組合的轉換:微分、向量化、JIT 到 GPU/TPU 等等。

下面是JAX的一些特點: - JAX 能夠對 Python+NumPy 程序進行可組合的轉換,比如微分、向量化、JIT 到 GPU/TPU 等等。 - 與 PyTorch 相比,JAX 最重要的方面是梯度計算。在 Torch 中,圖形是在前向傳播期間創建的,而梯度是在後向傳播期間計算的。另一方面,JAX的計算被表示為一個函數,使用方面更友好。 - JAX 是一個 autograd 工具,單獨使用它幾乎不是一個好主意。有各種基於 JAX 的 ML 庫,其中值得注意的是 ObJax、Flax 和 Elegy。由於它們都使用相同的核心,並且接口只是 JAX 庫的包裝器,因此我們將它們放在同一個括號中。

深度學習的成功很大程度上歸功於自動分化。TensorFlow和PyTorch等流行庫在訓練期間跟蹤神經網絡參數的梯度,兩者都包含用於實現深度學習常用神經網絡功能的高級 API。JAX是 CPU、GPU 和 TPU 上的 NumPy,對於高性能機器學習研究具有出色的自動區分能力。除了深度學習框架外,JAX 還創建了一個超級精巧的線性代數庫,具有自動微分和 XLA 支持。不過,JAX目前仍處於起步階段,不建議剛開始探索深度學習的人使用,因為它涉及很多的基礎函數和理論。

二、環境搭建

2.1 Python環境 Mac上搭建Python環境最好的做法是使用Homebrew來安裝,如果你還沒有安裝Python環境,可以使用下面的命令進行安裝。

brew search [email protected] brew install [email protected]

安裝過程中,可能會出現錯誤,比如Error: No such file or directory @ rb_sysopen,如下:

Pouring sqlite-3.38.5.arm64_monterey.bottle.tar.gz Error: No such file or directory @ rb_sysopen - /Users/xzh/Library/Caches/Homebrew/downloads/062e09dc048eab6bed4b64a9ce0533b08d65775640f901d27e24fd4c1ae640d7--sqlite-3.38.5.arm64_monterey.bottle.tar.gz

那麼,我們只需要按照提示,使用brew install 命令單獨安裝sqlite即可。安裝完成之後,再次運行brew install [email protected] 命令即可。安裝完成之後,最好配置下環境變量。

首先,使用open ~/.bash_profile 打開終端工具,然後將下面的代碼複製進去。

```

Setting PATH for Python 3.10

export PATH=${PATH}:/Library/Frameworks/Python.framework/Versions/3.10/bin alias python="/Library/Frameworks/Python.framework/Versions/3.10/bin/python3.10" export PATH=${PATH}:/Library/Frameworks/Python.framework/Versions/3.10/bin alias pip="/Library/Frameworks/Python.framework/Versions/3.10/bin/pip3" export PATH="$PATH:/usr/local/bin/python3.10" ```

配置完成之後,再使用source ~/.bash_profile命令使配置生效。輸入python命令,如果輸出如下信息,則説明Python環境安裝成功。

``` Python 3.10.5 (v3.10.5:f377153967, Jun 6 2022, 12:36:10) [Clang 13.0.0 (clang-1300.0.29.30)] on darwin Type "help", "copyright", "credits" or "license" for more information.

```

2.2 pip工具

pip是 Python包管理工具,提供了對Python 包的查找、下載、安裝、卸載的功能。注需要説明的是,Python 2.7.9 + 或 Python 3.4+ 以上版本都自帶pip工具,如果是最新的版本無需額外安裝。

如果使用pip安裝python插件時,提示command not found錯誤,可以證明你還沒有安裝pip工具,可以使用下面的命令進行安裝。

curl https://bootstrap.pypa.io/get-pip.py | python3

同時,在採用默認 pip3 安裝第三方庫的時候,經常會出現超時的情況。

pip._vendor.urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='files.pythonhosted.org', port=443): Read timed out.

對於這種異常,可以使用國內的鏡像源,比如: - 阿里雲:https://mirrors.aliyun.com/pypi/simple/ - 清華:https://pypi.tuna.tsinghua.edu.cn/simple - 中國科技大學: https://pypi.mirrors.ustc.edu.cn/simple/

當然,我們還可以打開 ~/.pip/pip.conf文件創建自己的配置文件,比如:

mkdir -p ~/.pip cat > ~/.pip/pip.conf<<eof [global] timeout = 6000 index-url = https://mirrors.aliyun.com/pypi/simple/ trusted-host = mirrors.aliyun.com eof

然後,執行安裝如果有下面的提示,則説明鏡像源已被替換。

Looking in indexes: https://mirrors.aliyun.com/pypi/simple/

2.3 JAX基本使用

2.3.1 JAX插件安裝

經過前面的介紹,我們知道,jax其實就是一個函數庫,所以我們使用之前,需要先安裝一下jax插件,安裝時需要使用pip命令安裝,命令如下:

pip install --upgrade pip pip install --upgrade "jax[cpu]"

關於如何安裝,大家可以參考下官方文檔的介紹。安裝成功之後,會給出成功的提示,如下圖。

image.png

除了CPU版本外,JAX還支持GPU和TPU,安裝的命令如下:

//GPU pip install --upgrade pip pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html //TPU pip install --upgrade pip pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

2.3.2 官方示例工程

JAX 的定位是科學計算(Scientific Computing)和函數轉換(Function Transformations),具有除訓練深度學習模型以外的一系列能力,具體包括: - 即時編譯(Just-in-Time Compilation) - 自動並行化(Automatic Parallelization) - 自動向量化(Automatic Vectorization) - 自動微分(Automatic Differentiation)

為了方便學習,我們最好下載下官網的示例工程並運行,示例工程代碼結構如下:

image.png

2.3.3 在線編程平台jupyter

Jupyter 是一款開放性的代碼編寫軟件,最大的特點是能夠實時運行代碼,查看輸出效果。同時該軟件集成了多種插件,能夠實現非常複雜的功能。對於初學者而言,因為Jupyter具有良好的交互性,方便看到每一行、每一個代碼塊的輸出結果,Jupyter成為眾多Python初學者的青睞。

image.png

同時,Jupyter提供了用 JupyterLab 和 Jupyter Notebook 等交互式編寫軟件的技術方式,能夠更好的幫助開發者編寫、運行代碼。其中,安裝JupyterLab命令如下:

pip3 install jupyterlab //啟動命令 jupyter-lab

安裝完成之後,在使用jupyter-lab即可啟動。如果是安裝Jupyter Notebook,那麼安裝的命令如下:

pip3 install notebook //啟動命令 jupyter notebook

啟動成功之後,會自動打開http://localhost:8888/tree頁面,如下圖。

image.png

然後,我們點擊右上角的【新建】按鈕新建一個運行面板,如下圖。

image.png

接下來,我們就可以在上面運行一些函數。當然,Jupyter還提供了在線編輯運行平台,可以幫助開發者快速的體驗Jupyter的魅力,目前支持主流的編程語言和技術。

image.png

image.png

如果要運行項目,可以使用【Ctrol】+回車即可得到運行結果。比如生成隨機數據:

key = random.PRNGKey(0) x = random.normal(key, (10,)) print(x)

當我們【Ctrol】+回車運行項目時,得到的結果如下:

[-0.3721109 0.26423115 -0.18252768 -0.7368197 -0.44030377 -0.1521442 -0.67135346 -0.5908641 0.73168886 0.5673026 ]

其中,比較常用的快捷鍵有如下一些: - Tab : 代碼補全或縮進 - Shift-Tab : 提示 - Ctrl-A : 全選 - Ctrl-Z : 復原 - Ctrl-Shift-Z : 再做 - Ctrl-Y : 再做 - Ctrl-Backspace : 刪除前面一個字 - Ctrl-Delete : 刪除後面一個字 - Esc : 進入命令模式 - Ctrl-M : 進入命令模式 - Shift-Enter : 運行本單元,選中下一單元 - Ctrl-Enter : 運行本單元 - Alt-Enter : 運行本單元,在下面插入一單元 - Ctrl-Shift-- : 分割單元 - Ctrl-Shift-Subtract : 分割單元 - Ctrl-S : 文件存盤 - Shift : 忽略 其他常用的快捷鍵,可以通過幫助選項來進行查看,如下圖。

image.png

2.3.4 基本使用

隨機函數

和編寫其他python的語法一樣,使用Jax之前需要導入相關的函數包。比如,我們使用NumPy執行一些基準測試,比如:

import jax import jax.numpy as jnp from jax import random from jax import grad, jit import numpy as np key = random.PRNGKey(0)

當然,我們也可以 import jax.numpy as jnp 並將代碼中的所有 np 替換為 jnp 。與NumPy 代碼風格不同,在JAX 代碼中,可以直接使用import方式導入並直接使用。可以看到,JAX 中隨機數的生成方式與 NumPy 不同。JAX需要創建一個 jax.random.PRNGKey 。

矩陣乘法

我們在 Google Colab 上做一個簡單的基準測試,這樣我們就可以輕鬆訪問 GPU 和 TPU。我們首先初始化一個包含 25M 元素的隨機矩陣,然後將其乘以它的轉置,使用針對 CPU 優化的 NumPy,矩陣乘法平均需要408 ms ± 35.9 ms。

``` size = 5000 x = np.random.normal(size=(size, size)).astype(np.float32) %timeit np.dot(x, x.T)

408 ms ± 35.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

```

在 CPU 上使用 JAX 執行相同的操作平均需要大約 716 ms ± 13.7 ms。

``` size = 5000 x = random.normal(key, (size, size), dtype=jnp.float32) %timeit jnp.dot(x, x.T).block_until_ready()

716 ms ± 13.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

```

在 CPU 上運行時,JAX 通常比 NumPy 慢,因為 NumPy 已針對CPU進行了非常多的優化。但是,當使用加速器時這種情況會發生變化,所以讓我們嘗試使用 GPU 進行矩陣乘法。

``` size = 5000 x = random.normal(key, (size, size), dtype=jnp.float32) %time x_jax = jax.device_put(x)
%time jnp.dot(x_jax, x_jax.T).block_until_ready()
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready()

CPU times: user 50 µs, sys: 1 µs, total: 51 µs

Wall time: 53.9 µs

CPU times: user 5.14 s, sys: 44.9 ms, total: 5.19 s

Wall time: 732 ms

725 ms ± 25.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

```

接下來,讓我們使用 TPU 來進行矩陣乘法。

``` size = 5000 x = random.normal(key, (size, size), dtype=jnp.float32) %time x_jax = jax.device_put(x)
%time jnp.dot(x_jax, x_jax.T).block_until_ready()
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready()

CPU times: user 54 µs, sys: 916 µs, total: 970 µs

Wall time: 973 µs

CPU times: user 5.25 s, sys: 34.2 ms, total: 5.28 s

Wall time: 709 ms

715 ms ± 3.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

```

可以看到,忽略設備傳輸時間和編譯時間,每個矩陣乘法平均需要715 ms ± 3.95 毫秒,與GPU 相比,TPU快了差不多4倍。需要説明的是,當乘以不同大小的矩陣時,獲得相同的加速效果也不同:相乘的矩陣越大,GPU可以優化操作的越多,加速也越大。

jit()

JAX在GPU上是透明運行的。但是,在上面的示例中,JAX一次將內核分配給GPU一次操作,如果我們有一系列操作,則可以使用@jit裝飾器使用XLA一起編譯多個操作。

``` def selu(x, alpha=1.67, lmbda=1.05): return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,)) %timeit selu(x).block_until_ready()

1.64 ms ± 91 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

```

可以使用加快速度@jit,它將在第一次selu調用jit-compile並將其之後緩存。

``` def selu(x, alpha=1.67, lmbda=1.05): return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jit(selu) %timeit selu_jit(x).block_until_ready()

455 µs ± 151 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

```

可以看到,使用jit裝飾器後,運行效率明顯提高。

grad()

除了評估數值函數外,我們還希望對值進行轉換,其中一種轉變是自動微分。在JAX中,就像在Autograd中一樣,可以使用grad()函數來進行梯度計算。

``` def sum_logistic(x): return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.) derivative_fn = grad(sum_logistic) print(derivative_fn(x_small))

[0.25 0.19661197 0.10499357]

```

接下來,讓我們使用極限微分來驗證我們的結果是否正確。

``` def first_finite_differences(f, x): eps = 1e-3 return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps) for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1964569 0.10502338]

```

求解梯度可以通過簡單調用grad()。grad()並jit()可以任意混合。在上面的示例中,我們先抖動sum_logistic然後取其派生詞。

``` def first_finite_differences(f, x): eps = 1e-3 return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps) for v in jnp.eye(len(x))])

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

//-0.035325598 ```

vmap()

JAX在其API中還有另一種轉換,那就是vmap()向量化映射。它具有沿數組軸映射函數的熟悉語義,但不是將循環保留在外部,而是將循環推入函數的原始操作中以提高性能。當與組合時jit(),它的速度可以與手動添加批處理尺寸一樣快。

``` mat = random.normal(key, (150, 100)) batched_x = random.normal(key, (10, 100))

def apply_matrix(v): return jnp.dot(mat, v)

print(apply_matrix(100)) ```

給定功能apply_matrix,然後在Python中循環執行批處理維度,但是這樣做的性能通常很差。

``` def naively_batched_apply_matrix(v_batched): return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched') %timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched

433 µs ± 2.02 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

```

如果,我們使用vmap()自動添加批處理支持,那效率就提高不少。

``` @jit def vmap_batched_apply_matrix(v_batched): return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap') %timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap

13.5 µs ± 19.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

```

事實上,vmap()可以與任意組成jit(),grad()和任何其它JAX變換。當然,JAX的函數還有很多,大家可以查看官方資料進行學習。

三、XLA架構

XLA 是 JAX(和其他庫,例如 TensorFlow,TPU的Pytorch)使用的線性代數的編譯器,它通過創建自定義優化內核來保證最快的在程序中運行線性代數運算。XLA 最大的好處是可以讓我們在應用中自定義內核,該部分使用線性代數運算,以便它可以進行最多的優化。

在TensorFlow中,XLA給TensorFlow帶來了如下提升: - 提高執行速度。編譯子計算圖以減少短暫運算的執行時間,從而消除 TensorFlow 運行時的開銷;融合流水線運算以降低內存開銷;並針對已知張量形狀執行專門優化以支持更積極的常量傳播。 - 提高內存使用率。分析和安排內存使用量,原則上需要消除許多中間存儲緩衝區。 - 降低對自定義運算的依賴。通過提高自動融合的低級運算的性能,使之達到手動融合的自定義運算的性能水平,從而消除對多種自定義運算的需求。 - 減少移動資源佔用量。通過提前編譯子計算圖併發出可以直接鏈接到其他應用的對象/頭文件對,消除 TensorFlow 運行時。這樣,移動推斷的資源佔用量可降低幾個數量級。 - 提高便攜性。使針對新穎硬件編寫新後端的工作變得相對容易,在新硬件上運行時,大部分 TensorFlow 程序都能夠以未經修改的方式運行。與針對新硬件專門設計各個整體運算的方式相比,這種模式不必重新編寫 TensorFlow 程序即可有效利用這些運算。

不過,XLA 最重要的優化是融合,即可以在同一個內核中進行多個線性代數運算,將中間輸出保存到 GPU 寄存器中,而不將它們具體化到內存中。這可以顯着增加我們的“計算強度”,即所做的工作量與負載和存儲數量的比例。融合還可以讓我們完全省略僅在內存中shuffle 的操作(例如reshape)。

下面我們看看如何使用 XLA 和 jax.jit 手動觸發 JIT 編譯。

使用 jax.jit 進行即時編譯

這裏有一些新的基準來測試 jax.jit 的性能。我們定義了兩個實現 SELU(Scaled Exponential Linear Unit)的函數:一個使用 NumPy,一個使用 JAX。

``` def selu_np(x, alpha=1.67, lmbda=1.05): return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha) def selu_jax(x, alpha=1.67, lmbda=1.05): return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = np.random.normal(size=(1000000,)).astype(np.float32) %timeit selu_np(x)

7.56 ms ± 18.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

```

可以看到,NumPy平均需要 7.6 毫秒。接下來,讓我們在 CPU 上使用 JAX運行,如下。

``` def selu_np(x, alpha=1.67, lmbda=1.05): return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha) def selu_jax(x, alpha=1.67, lmbda=1.05): return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,)) %time selu_jax(x).block_until_ready()
%timeit selu_jax(x).block_until_ready()

CPU times: user 5.27 ms, sys: 2.7 ms, total: 7.97 ms

Wall time: 3.57 ms

1.7 ms ± 94.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

```

可以看到,這種情況下,JAX明顯要比 NumPy 快。下一個測試是在 GPU 上使用 JAX。

``` x = random.normal(key, (1000000,)) %time x_jax = jax.device_put(x)
%time selu_jax(x_jax).block_until_ready()
%timeit selu_jax(x_jax).block_until_ready()

CPU times: user 54 µs, sys: 39 µs, total: 93 µs

Wall time: 96.1 µs

CPU times: user 2.27 ms, sys: 1.3 ms, total: 3.57 ms

Wall time: 1.71 ms

1.63 ms ± 45.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

```

可以看到,函數運行時間為1.63毫秒。下面我們用 jax.jit 測試它,觸發 JIT 編譯器使用 XLA 將 SELU 函數編譯到優化的 GPU 內核中,同時優化函數內部的所有操作。

``` x = random.normal(key, (1000000,)) selu_jax_jit = jit(selu_jax) %time x_jax = jax.device_put(x)
%time selu_jax_jit(x_jax).block_until_ready()
%timeit selu_jax_jit(x_jax).block_until_ready()

CPU times: user 114 µs, sys: 305 µs, total: 419 µs

Wall time: 426 µs

CPU times: user 30.2 ms, sys: 7.86 ms, total: 38.1 ms

Wall time: 36.5 ms

361 µs ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

```

可以看到,使用編譯內核,函數運行時間為0.36毫秒。之所以能帶來如此大的性能提升,是因為使用 JIT 編譯避免從 GPU 寄存器中移動數據,從未帶來了非常大的加速。一般來説在不同類型的內存之間移動數據與代碼執行相比非常慢,因此在實際使用時應該儘量避免。

將 SELU 函數應用於不同大小的向量時,您可能會獲得不同的結果。矢量越大,加速器越能優化操作,加速也越大。除了執行 selu_jax_jit = jit(selu_jax) 之外,還可以使用 @jit 裝飾器對函數進行 JIT 編譯,如下所示。

@jit def selu_jax_jit(x, alpha=1.67, lmbda=1.05): return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

JIT 編譯可以加速,為什麼我們不能全部都這樣做呢?因為並非所有代碼都可以 JIT 編譯,JIT要求數組形狀是靜態的並且在編譯時已知。另外就是引入jax.jit 也會帶來一些開銷。因此通常只有編譯的函數比較複雜並且需要多次運行才能節省時間。

參考鏈接:

https://jax.readthedocs.io/en/latest/installation.html
https://github.com/google/jax