Python(JAX) と Julia の比較

Pythonは読み書きしやすいプログラミング言語ですが、計算速度が遅いことが玉に瑕です。 これを克服する選択肢として、JAXJuliaを比較しました。

なぜPython(JAX)とJulia?

候補は以下のものがありました。

  • Python

    • NumPy & SciPy: コンパイルされたCコードを実行する技術計算ライブラリ

    • CuPy: NumPy、SciPyをGPUで実行

    • cuNumeric: NumPyをGPUで実行。マルチノード対応

    • Numba: Python関数をJITコンパイル

    • JAX: Python関数をJITコンパイル。自動微分に対応しTensorFlowPyTorchに続く機械学習ライブラリとして期待

  • Julia: Pythonのように書けて、計算速度はC言語とPythonの間にある[BKSE12]

  • MATLAB: 有償ソフトウェア。製品設計、制御設計のプラットフォーム

  • C++: GitHubで見かける最近の多くの技術計算オープンソースソフトウェアがC++で記述され、Pythonインターフェースで実行される

  • Fortran: 技術計算言語として長く利用されている。例えばBLASの他の多くのライブラリのバックエンド

  • Rust: 実行速度が速く、デバッグしやすい。

Pythonライブラリの中で、JAXは、NumPy・SciPy、GPU計算、JITコンパイル、自動微分のすべてをカバーする(おそらく唯一)のライブラリです。 一方、Juliaは、前述の機能をカバーした上で、技術計算に特化しています。Pythonと同じくJupyter環境で逐次実行できます。 MATLABは、JITコンパイル可能なようですが、ライセンスを購入しなければ使用できません。 C++、Fortranは、習熟やデバッグに時間がかかりそうなので避けました。さらなる高速が必要な時に、再度検討します。 Rustはデバッグしやすいらしいですが、技術計算ライブラリが充実しているようには見えませんでした。 そういうわけで、今回はJAXとJuliaを選択しました。

検証環境

Windows11のUbuntu 20.04 on WSL2で実行しました。Windows10 ver. 20H2以降は、WSLでGPUを利用できます。

CPUとGPUは、以下の表のものを用いました。

浮動小数点演算性能 (単位: TFLOPS)

FP32

FP64

CPU: AMD Ryzen9 5950X

0.97

GPU: NVIDIA GeForce RTX 3090

39.1

(0.61)

上記数値はベンチマークサイトを参考にしました。

どちらも2020年秋発売のメインストリーム向けフラグシップなので、CPU計算とGPU計算の比較には適しています。

インストール

JAX

公式のインストール手順にしたがってインストールしました。 WindowsでGPUバージョンのJAXを利用する場合、CUDA on WSLを利用します。

Attention

WSLではないWindows上で実行したい場合、ソースからビルドする必要があります。

pipパッケージに、CUDAとCuDNNがバンドルされていないので、別途インストールが必要です。 CUDAはここ、CuDNNは、ここに記載された通りの手順で導入しました。後述のJuliaに比べると面倒です。

Caution

JAX、CUDA、CuDNNのバージョンを一致させなければ動作しません。

Julia

Julia のインストールは、公式からダウンロードして、解凍してPathを通すだけです。

GPU計算(CUDA)の利用には、CUDA.jlを利用します。 Julia のコマンドプロンプトで以下のコマンド1行だけでセットアップ完了です。

julia> using Pkg; Pkg.add("CUDA")

手順がシンプルで、WindowsとLinuxで導入手順が共通しています。

読み書きしやさ

JAX

JITコンパイルはデコレータ@jitで指示します。 インスタンス変数に配列を持つclass内methodのJITコンパイルにはやや面倒な対策が必要でした(公式FAQ)。 Pythonらしさはなくなりますが、classを使わずに書くのが簡単そうです。 また、jax.numpyではない通常のnumpyがあるとJITコンパイルに失敗するので、膨大なライブラリがあるというPythonの利点を活かせないこともありそうです。

Julia

classの概念がなく、function (Pythonでいうmethod) で記述します。 JITコンパイルはユーザーが明示的に指示する必要がありません。

しかし、Pythonに比べると、APIドキュメントの情報量が少なく、ライブラリ使い方の理解に時間がかかることが多い印象です。

計算時間の比較

比較条件

以下のポアソン方程式を周期境界条件で解きました。

\[ \frac{\partial^2 p}{\partial x^2} + \frac{\partial^2 p}{\partial y^2} = \frac{\partial u}{\partial x} + \frac{\partial v}{\partial y} \]
\[\begin{split} u = \sin 2 x \\ v = \sin 2 y \end{split}\]

ソースコードは、JAXJuliaの通りです。 差文法の実装の違い、線形代数ソルバーの違いがあるために、厳密な比較にはなっていません。

比較結果

JIT-compileされたJAXとJuliaが非常に高速であることがわかります。 配列サイズ10万を堺にCPUとGPUの速度が逆転することがわかります。 配列サイズが大きい場合には、CPUとGPUの差はベンチマークサイトの結果通りでした。 GPU計算は自動的に単精度実数として実行されるようです。

elaplsed_time

まとめ

JITコンパイルJAXとJuliaが非JITコンパイルJAXに比べて非常に高速であることがわかりました。 非JITコンパイルJAXが通常のPythonと比べてどうなのか、同じアルゴリズムの時にJITコンパイルJAXとJuliaの差はどの程度かは時間があれば調べたいです。

導入しやすさと実装しやすさは、Juliaが明らかに勝っています。Pythonでなければならない理由がない限りは、私はJuliaを選択します。

BKSE12

Jeff Bezanson, Stefan Karpinski, Viral B. Shah, and Alan Edelman. Julia: a fast dynamic language for technical computing. 2012. URL: https://arxiv.org/abs/1209.5145, doi:10.48550/ARXIV.1209.5145.