JAXでGPUを使う
Pytorchだと.cuda()などでGPUを使うのだが、JAXの場合はpipインストール時にcudaのバージョンを指定することでGPUを使うことができる。
環境にインストールされているCUDAのバージョンが10.02の場合、以下を入力しJAXのインストールを行う。
pip install -U jax -f https://storage.googleapis.com/jax-releases/jax_releases.htmlpip install -U jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.htmlJAXがGPUを使っているかどうかは、以下で確かめる。
import jaxjax.default_backend()上記のコマンドを入力後、‘gpu’がかえってこれば正常に使えているということになる。
また、以下のコマンドを打つと、
jax.local_devices()使われているGPUのidと数が以下のように把握可能。