JAX-CFDをWindowsのAnacondaにインストールする
By K.Yoshimi
JAX-CFDは、Googleが開発している、数値流体力学における機械学習、自動微分、ハードウェアアクセラレータ(GPU/TPU)の可能性を探るための実験研究プロジェクトで、 JAXで実装されています。
今のところ、非定常の乱流計算に焦点が当てられていますが、定常のRANS計算なども今後、実装されていく予定とのことです。
将来的には、jax-cfdの機能は3次元に拡張され、ANSYS Fluentのソルバーの高速化のために導入されるとの情報もあります。
今回は、JAX-CFDをWindowsのAnaconda上にインストールします。
Visual Studoのインストール
JAX-CFDをAnaconda上で使用するためには、Jaxlibをソースコードからビルドする必要があるため、C++コンパイラーが必要です。
ここでは、Visual Studio 2019をインストールしました。
Cuda,cuDNNのインストール
私のマシンに入っている、NVIDIA GeForce RTX 2080 Tiに対応している
CUDA Toolkit 11.6
cuDNN 8.3.2
をインストールしました。
ここで、システム環境変数のPathに
C:\PROGRA~1\NVIDIA~2\CUDA\v11.6\bin
C:\PROGRA~1\NVIDIA~2\CUDA\v11.6\libnvvp
があることを確認し、無い場合は追加します。
仮想環境の作成
まず、AnacondaにJAX-CFDの作業をするための仮想環境を追加します。
Anaconda Promptを立ち上げ、下記のコマンドを実行すると、新しい仮想環境jax_cfdが追加されます。
ここでは、pythonのバージョンとして、3.8を指定しています。
(base) C:/Users/(ユーザ名)>conda create -n jax_cfd python=3.8
実際に、新しい仮想環境が追加されたことを確認するために、conda info -eを実行しますと、jax_cfdを確認できます。
(base) C:/Users/(ユーザ名)>conda info -e
# conda environments:
#
base * C:/Users/(ユーザ名)/Anaconda3
jax_cfd C:/Users/(ユーザ名)/Anaconda3/envs/jax_cfd
また、追加した仮想環境jax_cfdに移動するために、activateコマンドを使用します。
(base) C:/Users/(ユーザ名)>activate jax_cfd
(jax_cfd) C:/Users/(ユーザ名)>
下記ビルド時にnumpyモジュールが要求されますので、先に、numpyをインストールしておきます。
(jax_cfd) C:/Users/(ユーザ名)>pip install numpy
JAXのインストール
次に、JAXをインストールします。
JAXは、LinuxやmacOSにはバイナリが用意されていますが、Windows用は用意されておりませんので、ソースコードから生成する必要があります。
MSYS2のインストール
まず、ソースのビルドに必要なコマンドを導入します。そのために、MSYS2をインストールします。
リンク先から、インストーラーをダウンロードし、指示通りに実行してインストールします。
インストールしたら、MSYS2のシェル上で、下記のコマンドを実行します。
pacman -S patch coreutils
これにより、patchやrealpathといったコマンドが使用できるようになります。
続けて、システム環境変数の編集でインストールしたMSYS2のbinをパスに追加します。
下図のように
システムのプロパティ ⇒ 環境変数(N)…
と進み、システム環境変数の中にあるPathに対して、インストールしたMSYS2のbin
C:/msys64/usr/bin
を追加します。
以上により、Anaconda Promptからも、これらのコマンドを利用できるようになります。
JAXのソースコードをダウンロード
JAXのソースコードをリンク先から、適当な場所に、ダウンロードあるいは、Gitでクローンします。
Jaxlibのビルド
Anaconda Promptを一旦閉じてから、再び開き、仮想環境jax_cfdに移動します。
※ システム環境変数の変更を有効にするためには、Promptの再起動が必要です。
Prompt上で、patchやrealpathのコマンドをタイプし、存在していることを確認しましょう。
それでは、ダウンロードしたJAXのフォルダに移動しましょう。
(jax_cfd) C:/Users/(ユーザ名)> cd C:/(ダウンロード先)/jax
使用するCUDAやcuDNNのインストール先や、バージョンを指定して、下記のようなコマンドでビルドを実行します。
python build/build.py --enable_cuda^
--cuda_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.6"^
--cudnn_path="C:/Program Files/NVIDIA/CUDNN/v8.3"^
--cuda_version="11.6"^
--cudnn_version="8.3.2"
ビルドには、少々時間がかかりますので、辛抱強く待ちます。
※ エラーが出る場合もありますが、大抵は、Google検索で解決します。
ビルドが終わると、jaxlibのwheelファイルが、distフォルダ下に生成されますので、それをpipコマンドでインストールします。
(jax_cfd) C:/(ダウンロード先)/jax> pip install C:/(ダウンロード先)/jax/dist/jaxlib-0.1.77-cp38-none-win_amd64.whl
続けて、JAX本体を下記コマンドで、インストールします。
(jax_cfd) C:/(ダウンロード先)/jax> pip install -e .
JAX-CFDのインストール
JAX-CFDのインストールは、下記のコマンドを実行します。
(jax_cfd) C:/(ダウンロード先)/jax>pip install jax-cfd
また、デモの計算では、seaborn、xarrayが要求されますので、これらも先に、インストールしておきます。
C:/(ダウンロード先)> pip install seaborn xarray
JAX-CFDのダウンロード
JAX-CFDのソースコードをリンク先から、適当な場所に、ダウンロードあるいは、Gitでクローンします。
デモの実行
JAX-CFDのnotebookフォルダ中に、Jupyter Notebook用のサンプルがいくつかあります。
jax-cfd/notebook
ここでは、demo.ipynbから作成した、下記のdemo.pyファイルを適当なフォルダに置き、実行します。
import jax
import jax.numpy as jnp
import jax_cfd.base as cfd
import numpy as np
import seaborn
import xarray
import matplotlib.pyplot as plt
size = 256
density = 1.
viscosity = 1e-3
seed = 0
inner_steps = 25
outer_steps = 200
max_velocity = 2.0
cfl_safety_factor = 0.5
# Define the physical dimensions of the simulation.
grid = cfd.grids.Grid((size, size), domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))
# Construct a rondom initial velocity. The `filtered_velocity_field` function
# ensures that the initial velocity is divergence free and it filtered out
# high frequency fluctuations.
v0 = cfd.initial_conditions.filtered_velocity_field(
jax.random.PRNGKey(seed), grid, max_velocity)
# Choose a time step.
dt = cfd.equations.stable_time_step(
max_velocity, cfl_safety_factor, viscosity, grid)
# Define a step function and use it to compute a trajectory.
step_fn = cfd.funcutils.repeated(
cfd.equations.semi_implicit_navier_stokes(
density=density, viscosity=viscosity, dt=dt, grid=grid),
steps=inner_steps)
rollout_fn = jax.jit(cfd.funcutils.trajectory(step_fn, outer_steps))
_, trajectory = jax.device_get(rollout_fn(v0))
# JAX-CFD uses GridVariable objects for input/output. These objects contain:
# - array data
# - an "offset" that documents the position on the unit-cell where the data
# values are located
# - grid properties
# - boundary conditions on the variable
with np.printoptions(edgeitems=1):
for i, u in enumerate(trajectory):
print(f'Component {i}: {u}')
# load into xarray for visualization and analysis
ds = xarray.Dataset(
{
'u': (('time', 'x', 'y'), trajectory[0].data),
'v': (('time', 'x', 'y'), trajectory[1].data),
},
coords={
'x': grid.axes()[0],
'y': grid.axes()[1],
'time': dt * inner_steps * np.arange(outer_steps)
}
)
def vorticity(ds):
return (ds.v.differentiate('x') - ds.u.differentiate('y')).rename('vorticity')
(ds.pipe(vorticity).thin(time=20)
.plot.imshow(col='time', cmap=seaborn.cm.icefire, robust=True, col_wrap=5))
plt.show()
実行コマンドは、下記です。
(jax_cfd) C:/work/jax_cfd_demo> python demo.py
要素数は65,025で、内部繰り返し25回、外部繰り返し200回ですが、GPUで計算しているので10秒とかからずに終わます。
出力としては、論文に掲載されている図が描画されます。