読者です 読者をやめる 読者になる 読者になる

混合ポアソン分布を題材に、変分ベイズ法を理解する④

実装 ベイズ Python

こんにちは。吉田弁二郎です。有給消化中です。

前回、混合ポアソン分布に対して変分ベイズ法を適用し、実際に学習が行われた様子を見ました。
yoshidabenjiro.hatenablog.com

ところで、推定した分布が真の分布とどの程度異なっているのか、やはり定量的に把握したいと思うのが人情というものです。そこで今回は補足的な内容として、Kullback-Liebler ダイバージェンス(もしくは KL 情報量)を基準として学習の進捗を観察します。

Kullback-Leibler ダイバージェンス

まず理論的な背景を整理します。確率変数  x についての二つの確率分布  p(x) q(x) があるとします。この時、確率分布間の「距離」尺度として Kullback-Leibler ダイバージェンスと呼ばれる次の量を計算することができます:
\begin{align}
KL(q||p) &\equiv -\int dx q(x) \log\frac{p(x)}{q(x)}\\
&\geq 0.
\end{align}この不等式は  \log に関する関係式  x - 1 \geq \log (x) から従います。等号成立は、全ての  x について  q(x)=p(x) となる時に限ります。推定された分布が真の分布に近ければ近いほど、この値は下限であるゼロに近づくと期待されます。

「距離」と書いたのは、 q p の入れ替えについて対称ではないので、距離としての通常の性質を満たさないためです。とはいえ便利な量なのでよく使われるのが実情です。情報幾何学において非常に重要な役割を果たすことは、また改めて書きたいと思います。

混合ポアソンモデルにおける KL ダイバージェンスの具体的な表式

さて、今回の混合ポアソン分布推論問題においては、潜在変数  Z および分布のパラメータである  \lambda,  \pi について変分近似  q(Z, \lambda, \pi) = q(Z)q(\lambda, \pi) を課して計算をすると、結果的に  q(Z, \lambda, \pi) = q(Z)q(\lambda)q(\pi) となることがわかりました。これが真の事後分布  P(Z|X) = P(X, Z) / P(X) とどの程度離れているか/近づいているか、KL ダイバージェンスによって評価してみましょう。なお、分布推定の更新式についてはyoshidabenjiro.hatenablog.comにまとめています。

通常は  P(Z|X) に対しての  q(Z) の KL ダイバージェンスを計算しますが、今回は計算の都合上、 q(Z) に対して  P(X, Z) の KL ダイバージェンスを求めます( Z,  \lambda,  \pi をまとめて  Z と表記)。この時、
\begin{align}
KL(q||p) &= -\int dZ q(Z)\log\frac{P(X, Z)}{q(Z)}\\
&= -\int dZ q(Z)\log\frac{P(Z|X)}{q(Z)} -\log P(X)\\ \label{1} \tag{1}
&\geq -\log P(X) \geq 0
\end{align}となり、データ  X に依存したゼロ以上の下限を持つことが示されます。

さて、前回までに得られた結果を使うと、推定されたパラメータ  \pi_{nk}, r_{k}, s_{kd}, \alpha_{k} およびデータを生成した分布のパラメータ  r^{(0)}, s^{(0)}, \alpha_{k}^{(0)} により、\eqref{1} は下記のように表されます:
\begin{align}
KL(q||p) &= -\langle\log P(X|Z, \lambda)\rangle_{Z, \lambda} + \langle\log q(Z)\rangle_{Z} - \langle\log P(Z|\pi)\rangle_{Z, \pi}\\
&\quad + \langle\log q(\lambda)\rangle_{\lambda} - \langle\log P(\lambda)\rangle_{\lambda} + \langle\log q(\pi)\rangle_{\pi} - \langle\log P(\pi)\rangle_{\pi}\\
&= -\sum_{n,k,d}\pi_{nk}\Bigl(x_{nd}\bigl(\psi(s_{kd}) - \log (r_{k})\bigr) - \frac{s_{kd}}{r_{k}} - \log x_{nd}!\Bigl)\\
&\quad + \sum_{n,k}\pi_{nk}\Bigl(\log\pi_{nk} - \bigl(\psi(\alpha_{k}) - \psi(\alpha_{0})\bigr)\Bigr)\\
&\quad + \sum_{k,d}\Bigl\{s^{(0)}\log\frac{r_{k}}{r^{(0)}} + (s_{kd} - s^{(0)})\psi(s_{kd}) - (r_k - r^{(0)})\frac{s_{kd}}{r_{k}} + \log\frac{\Gamma(s^{(0)})}{\Gamma(s_{kd})}\Bigr\}\\
&\quad + \log\frac{C(\alpha_{0})}{C(\alpha_0^{(0)})} + \sum_{k}(\alpha_k - \alpha_{k}^{(0)})\bigl(\psi(\alpha_{k}) - \psi(\alpha_{0})\bigr). \label{2} \tag{2}
\end{align}各推定ステップにおける  \pi_{nk}, r_{k}, s_{kd}, \alpha_{k} を都度代入すれば、KL ダイバージェンスの時間発展が計算できます。

数値計算の結果

このようなサンプルデータf:id:yoshidabenjiro:20170107002138p:plainに対して\eqref{2}を適用してみると、KL ダイバージェンスの推定ステップに対する変化は下図のようになりました。f:id:yoshidabenjiro:20170107002150p:plain縦軸が KL ダイバージェンスの値、横軸がパラメータ更新回数です。図を観察すると、

  1. パラメータを更新するたびにダイバージェンスは小さくなる
  2. 数回のパラメータ更新でプラトーに達する
  3. プラトーに達するための更新回数はサンプルデータに依存する
  4. 下限はゼロより大きい

ことがわかると思います。特に4.は、\eqref{1}で  -\log P(X) の下限が生じると述べたことと合致しますね。この程度のデータであれば10回に満たないパラメータ更新を繰り返せば十分なようです。データが高次元で今回のような可視化が難しい場合であっても、KL ダイバージェンスを計算することで学習完了をおおよそ把握することができそうです。

今回 KL ダイバージェンスを計算するために書いたコードは blog/pmm.py at master · dlnp2/blog · GitHub に置いてあります。