luggage baggage

Machine learning, data analysis, web technologies and things around me.

TensorFlow Probability で VAE (Variational AutoEncoder): TFP 入門所感、tfp.distributions の初歩

最近になり、TensorFlow の肩に乗って確率的プログラミングをするためのライブラリ TensorFlow Probability (TFP) v0.5 がリリースされました。私は画像系タスクを TensorFlow を使って解くことが多く、特にこの記事では、画像生成系タスクに取り組むためのツールとして TFP を実際に使った際に学んだ内容について記録しておきます。

TFP に入門してみた所感

今のところのざっくりとした感触は、総じてデータサイエンティスト/アナリストではなく機械学習エンジニア向けの確率的プログラミング言語だなという感じです(ある意味当然ですが)。いくつか項目を挙げると、

  1. eight schools 的な、いわゆる確率モデリングをするのであれば、もっとこなれた他のライブラリを使ったほうが良さそう
  2. VAE (Variational AutoEncoder: 変分オートエンコーダ), Flow-based generative models などニューラルネットの構築が必要となる確率モデルを実装する上では強力なツール
  3. GPU/TPU による高速化の恩恵を受けられる点も Good(私は未検証なので論文ベースの知見)

と思いました。

1番目についていうと、やはり TensorFlow 特有のやりづらさは引き継いでいるところがあります(下の方の「ありがちなミス」欄に書きます)。個人的には2番めの項目が TFP を使う主な動機になります。Pyro という PyTorch ベースの類似ライブラリもありますが、私は TensorFlow ユーザなのでいったん放置です。3番めについては、Google Brain より論文*1が出ており、

With NUTS, we see a 100x speedup on GPUs over Stan and 37x over PyMC3

ということなので、使い方によっては速度面で大きな恩恵を得られる可能性がありそうです。

この記事では、上の2番めの項目に挙げた「ニューラルネット+確率プログラミング」をやる良い事例として、公式に提供されている VAE (Variational AutoEncoder) のサンプルコードを理解できるように、確率分布の作り方・使い方を中心として TFP の基本事項を簡単にまとめます*2。特に重要だと思った batch_shapeevent_shape はある程度厚めに解説したつもりです。VAE の実装や実際に動かした結果は次の記事に集約したいと思います。

(2018.11.14 WIP ですがいったん公開。)

確率分布を表現する tfp.distributions

まず少しだけ全体像を確認すると、TFP には4つのレイヤーがある、と公式では説明されています*3。詳細は省きますが、TensorFlow をベースに、生の確率分布を記述するための tfp.distributions, tfp.bijectors を用意した上で、より高水準のモジュール(確率モデル構築と確率推論)として tfp.edward2tfp.mcmc が提供されています。自分が必要とする抽象度でコードが書ける点は柔軟で良いかなと思います。

これらの中で、VAE を書くためにまず必要となるのが tfp.distributions です。以下、

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tf.enable_eager_execution()

とした状態を前提とします(TF2.0 からは eager execution がデフォルト化する*4ということもあり)。

最初の例:ベルヌーイ分布(コインの裏表)

最も簡単な例として、ベルヌーイ Bernoulli 分布を作ってみましょう。

bern = tfd.Bernoulli(probs=[0.7])
bern 
# <tfp.distributions.Bernoulli 'Bernoulli_1/' batch_shape=(1,) event_shape=() dtype=int32>

ベルヌーイ分布は事象が起こるか否かの2値を取る確率変数の値の分布を表すので、表が出る確率=probs を指定すれば完了です。この分布からサンプリングするには、

bern.sample()
# <tf.Tensor: id=28, shape=(1,), dtype=int32, numpy=array([1], dtype=int32)>

とします。この場合だと、値は array([1], dtype=int32) となりましたね。確率 0.7 で表が出るコインを用意した状態なので、まずは表が出ました。複数回サンプリングするためには、引数として回数を入れます。

bern.sample(3)
# <tf.Tensor: id=103, shape=(3, 1), dtype=int32, numpy=
# array([[1],
#        [1],
#        [1]], dtype=int32)>

さらに、引き出すサンプルの shape を指定することもできます。

bern.sample([2, 3])  # (2, 3) 型のサンプルを引く
# <tf.Tensor: id=127, shape=(2, 3, 1), dtype=int32, numpy=
# array([[[1],
#         [1],
#         [1]],

#        [[0],
#         [1],
#         [1]]], dtype=int32)>

ほかにも多数種の確率分布を作れたり、対数確率や KL ダイバージェンスを計算できたりしますが、基本はこれだけです。生成モデルなどを扱う際には確率変数というよりも確率分布自体に関心がある場合が多く、その意味では確率分布関連のオブジェクトとメソッドをまず網羅しているのは合理的だと思いました。

大事な概念:batch_shape と event_shape

batch_shape: 独立同分布の確率変数をまとめるもの(ミニバッチ的な)

ところで、上の例をよく見てみると、確率分布 bernbatch_shapeevent_shape という属性を持っていることに気がつくと思います。実際、

bern.batch_shape  # TensorShape([Dimension(1)])

となります。この2つの属性は地味なようですが重要で、おそらく TFP を使い始めた当初は、これらの理解不足が原因でコードが走らないことが頻繁に起こるような気がします。

まず batch_shape を理解するために、次のようにしてみましょう。

bern3 = tfd.Bernoulli(probs=[.3, .5, .7])
bern3
# <tfp.distributions.Bernoulli 'Bernoulli/' batch_shape=(3,) event_shape=() dtype=int32>
bern3.sample()
# <tf.Tensor: id=211, shape=(3,), dtype=int32, numpy=array([0, 0, 1], dtype=int32)>

今度は batch_shape(3,) となり、サンプリング結果は長さ3の配列となりました。一方で even_shape は空ですね。これらは、bern3 が「1次元の独立な3つのベルヌーイ分布を“まとめた”もの」であることを表しています。“まとめた”とは、「便宜上3つの分布から一挙にサンプリングできるようにした」ということで、ニューラルネットの文脈でいうと、この batch_shape は例えばバッチサイズが該当したりします。

event_shape: 確率変数の次元を表すもの

event_shape については、サンプリングではなく確率値を計算させた際に違いが出てきます。1変数の例に戻ると、

bern.prob([1.])  # 確率変数が 1 となる確率を求める
# <tf.Tensor: id=248, shape=(1,), dtype=float32, numpy=array([0.7], dtype=float32)>

とすると分かる通り、確率分布を決めたとき、ある値に対応する確率値を求めることができます。これは p(x) のことで、ここではコインの表 (x=1) が出る確率を 0.7 と定義していたので当然の結果が返ってきたことになります。3変数の例であれば、

bern3.prob([1., 1., 1.])  # 確率変数が各々 1 となる確率を求める
# <tf.Tensor: id=262, shape=(3,), dtype=float32, numpy=array([0.29999998, 0.5       , 0.7       ], dtype=float32)>

となって、要素数3の配列が返ってきました。これが意味するところは、便宜上3つ組にされた独立な確率分布のそれぞれが、各々 1 という値を返す確率を計算しましたよ、ということです。

確率を扱う際、複数の確率変数が(一般には相関を持つ)組としてある値を取る確率、つまり同時確率を計算するシーンは多いと思います。TFP で(独立な変数の)同時確率を表現するために使うのが、tfd.Independent です。言葉よりもコードの方がわかりやすいかもしれません:

bern3_joint = tfd.Independent(bern3, reinterpreted_batch_ndims=1)
bern3_joint
# <tfp.distributions.Independent 'IndependentBernoulli/' batch_shape=() event_shape=(3,) dtype=int32>

今度は、event_shape が空ではなくなりましたね!試しにサンプリングしてみましょう。

bern3_joint.sample()
# <tf.Tensor: id=300, shape=(3,), dtype=int32, numpy=array([0, 1, 1], dtype=int32)>

これは、前に見た結果と同じですね。つまり単にサンプリングするだけであれば、bern3.sample()bern3_joint.sample() の間に違いはありません。確率を計算させるとどうでしょうか?

bern3_joint.prob([1., 1., 1])
# <tf.Tensor: id=317, shape=(), dtype=float32, numpy=0.105000004>

出力は単一の値 0.105000004 となり、ようやくこれまでと違う結果が出てきました。

少し振り返ると、いま考えている分布は probs=[.3, .5, .7] のベルヌーイ分布を結合したもので、0.3 * 0.5 * 0.7 = 0.105 となりますね。つまり、bern3_joint は、「3次元の確率変数で、各次元が独立なベルヌーイ分布に従う」分布を表現しているということです。この意味で、event_shape は確率変数の次元を表していると思ってよいです。

多変数正規分布の場合

event_shape をもう少し直感的に分かるようにするため、次に多次元正規分布を作ってみます。簡単のため次元は2で、無相関な2つの変数を扱っているとしましょう。

mvn = tfd.MultivariateNormalDiag(loc=[0., 0.], scale_diag=[1., 1.])  # 共分散行列が単位行列になっている
mvn
# <tfp.distributions.MultivariateNormalDiag 'MultivariateNormalDiag/' batch_shape=() event_shape=(2,) dtype=float32>

そうすると、event_shape が (2,) となりました。原点における確率を求めると、

mvn.prob([0., 0.])
# <tf.Tensor: id=440, shape=(), dtype=float32, numpy=0.15915495>

0.15915495 という値が出てきますが、これは  1/2\pi に一致します。独立な正規分布2つの積を原点で評価した値になるということですね。

ありがちなミス

さて、TFP は TF ベースゆえに TF ならではの低レイヤー感を味わうことがあったり、ちゃんとドキュメントを読まないとスルーしてしまいそうな不都合があるので、3つほどご紹介しておきます。

変数の型が合っていない

意図せず不合理な結果を返す場合

shape のミスマッチ