JAXでGPUを使う

Pytorchだと.cuda()などでGPUを使うのだが、JAXの場合はpipインストール時にcudaのバージョンを指定することでGPUを使うことができる。

環境にインストールされているCUDAのバージョンが10.02の場合、以下を入力しJAXのインストールを行う。

pip install -U jax -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install -U jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html

JAXがGPUを使っているかどうかは、以下で確かめる。

import jax
jax.default_backend()

上記のコマンドを入力後、‘gpu’がかえってこれば正常に使えているということになる。

また、以下のコマンドを打つと、

jax.local_devices()

使われているGPUのidと数が以下のように把握可能。