luggage baggage

Machine learning, data analysis, web technologies and things around me.

jax を使いコンタクトマップをマイクロ秒で計算する

先日、PDB ファイルから読んだタンパク質座標データを元に、コンタクトマップを numpy でミリ秒オーダーで計算する方法を書きました。
PDB ファイルを読み込みコンタクトマップをミリ秒で計算する - luggage baggage
コンタクトマップとは、3次元空間内に存在するタンパク質を構成する炭素原子同士の距離を計算し(distance map)、一定の閾値以下であればコンタクトしている、と判定するものです。

今回は、Google が実験的に公開している jax という "numpy with GPU/TPU" とでも言えるライブラリ*1を使うことで、マイクロ秒オーダーで計算する方法を紹介したいと思います。なお jax の大きな特徴の一つは自動微分機能ですが、今回はこれに触れず、単に numpy like な API を使った計算を手軽に GPU 上で走らせるためのものとして使います。

(2020.02.18追記) 先日の記事にも追記しましたが、scipy.spatial.distance.cdist を使うと普通に CPU 上でマイクロ秒が達成できます。今回紹介している jax を使うと、手元の環境では scipy 版の倍速程度まで高速化できることが確認できました。ただし、本記事では通常の numpy とほとんどコードを変えずに 30 倍以上の高速化を jax で実現できる、というのが主な内容のため、広いケースでその使用を検討してよいように思います。

検証環境

  • Ubuntu 18.04
  • Python 3.6.10
  • numpy 1.18.1
  • jax 0.1.59
  • jaxlib 0.1.38
  • CUDA Version: 10.1

なお、今回は conda を使っておらず、pip でライブラリのインストールをしています。そのため numpy は MKL を使わない状態で入っています。

numpy.show_config()

"""
blas_mkl_info:
  NOT AVAILABLE
blis_info:
  NOT AVAILABLE
openblas_info:
    libraries = ['openblas', 'openblas']
    library_dirs = ['/usr/local/lib']
    language = c
    define_macros = [('HAVE_CBLAS', None)]
blas_opt_info:
    libraries = ['openblas', 'openblas']
    library_dirs = ['/usr/local/lib']
    language = c
    define_macros = [('HAVE_CBLAS', None)]
lapack_mkl_info:
  NOT AVAILABLE
openblas_lapack_info:
    libraries = ['openblas', 'openblas']
    library_dirs = ['/usr/local/lib']
    language = c
    define_macros = [('HAVE_CBLAS', None)]
lapack_opt_info:
    libraries = ['openblas', 'openblas']
    library_dirs = ['/usr/local/lib']
    language = c
    define_macros = [('HAVE_CBLAS', None)]
"""

jax のインストール

公式*2に従います。私の環境では、次の設定で問題なくインストールできました(python -m venv .venv で仮想環境を設定後)。

PYTHON_VERSION=cp36  # alternatives: cp35, cp36, cp37, cp38
CUDA_VERSION=cuda101  # alternatives: cuda92, cuda100, cuda101, cuda102
PLATFORM=linux_x86_64  # alternatives: linux_x86_64
BASE_URL='https://storage.googleapis.com/jax-releases'
pip install -U pip
pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.38-$PYTHON_VERSION-none-$PLATFORM.whl

pip install --upgrade jax  # install jax

jax による計算の高速化

jax には numpy と同様の API (jax.numpy.linalg.norm など) が実装されており、あたかも通常の numpy を使っているかのようにコードを書くことができます。また、jax.jit により Python 関数を XLA 最適化されたコードにコンパイルでき、大幅な高速化が手軽に実行できます。

全体としては、次のような感じになります。

import jax.numpy as np
import numpy as onp  # jax.numpy と区別
from jax import devices, device_put, jit


# 使用する GPU デバイスを明示的に指定。特に入れなくてもよい
device = devices()[0]

# coords は、前回の記事で用意したタンパク質の座標データを numpy.array に格納したもの。
# device_put により numpy.array を GPU に載せておきます。
coords_gpu = device_put(coords, device=device)

# ベンチマーク用 numpy 版 distance map 計算関数
def distance_map(coords):
    return onp.linalg.norm(coords[:, None, :] - coords, axis=-1)

# 上記関数の jax 版。jax.numpy.linalg は、numpy.linalg に相当。
# None による axis の追加が numpy 同様にできており地味に便利(numba ではできなかった記憶)
def distance_map_jax(coords):
    return np.linalg.norm(coords[:, None, :] - coords, axis=-1)

# 上記関数の jit コンパイル版
distance_map_jit = jit(distance_map_jax)

これらの関数を使って distance map を計算した際の処理時間を比べます。

まずベースラインとして、通常の numpy では

%timeit distance_map(coords)
# 6.06 ms ± 646 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

となり、ミリ秒オーダーで完了しています。

jax 版は、

%timeit distance_map_jax(coords_gpu)
# 3.36 ms ± 423 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

で倍速程度になっているようです。

ところが jax.jit を使うと、

%timeit distance_map_jit(coords_gpu)
# 178 µs ± 9.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

となり、マイクロ秒オーダーで計算が完了しています。約34倍の高速化が簡単に実現できました!jit は一度コンパイルする必要があるので、%timeit を2回呼んだうちの2回めの結果を記載しています。

まとめ

jax による distance map (contact map はこれを元に閾値判定するのみ) の高速化手法を紹介しました。非常に手軽だし、何か計算のボトルネックがあって GPU が利用できるような場合には積極的に使ってみてよいのではないでしょうか。ただし、まだ Windows 対応されていないとか、どのような関数に対しても使えるわけではない、jax 配列は気軽に値の代入ができない、などの制約があったりもする*3ため、自分のやりたいことと適宜相談しつつの利用になりそうです。