luggage baggage

Machine learning, data analysis, web technologies and things around me.

Optuna で混合ガウス分布の混合比を推定(パラメータ間に依存関係・制約条件がある場合の例)

機械学習モデルのハイパーパラメータを最適化するのに Optuna を使っている人も多いと思います。最適化目標(深層学習のロスや RMSE など)を決め、探索空間を適当に設定するだけでいい感じにパラメータを求めてくれるので便利ですね。

ところで、Optuna におけるパラメータの指定方法は、基本的には各パラメータごとに独立におこなうものと理解しています。これは例えば機械学習モデルのハイパラサーチのときには自然な設定だと思われるのですが、パラメータの和が一定値になる、みたいな制約(等号制約)下ではどのようにすればよいか理解していませんでした。

そこでこの記事では、この種の問題の例として、混合ガウス分布の混合比を Optuna に決定させてみます。混合ガウス分布は複数のガウス分布の重み付け和として得られる分布であり、この重み(混合比)は当然和が 1 になります。

扱うデータ

今回は簡単のため、各ガウス分布の平均と共分散行列は最適化する対象とはせず固定しておきます。混合要素は3つ、x軸上に等間隔に並んでいるとし、各々に対する真の重みをそれぞれ 0.2, 0.3, 0.5 としましょう。分布は次の図のようになります。
f:id:yoshidabenjiro:20200607193957p:plain
このとき、Optuna を使って混合比 0.2, 0.3, 0.5 に近い値が出力されてくるかを確認していきましょう。

objective の定式化

Optuna を使う際には何らかの最適化目標を与える必要があります。ここでは、negative log likelihood を最小化することによりパラメータを決めていきます。対応する Python コードは次のようになります。

def objective(data, trial):
    # Minimze negative log likelihood of mixed gaussian distribution.

    w1 = trial.suggest_uniform("w1", 0, 1.0)
    w2 = trial.suggest_uniform("w2", 0, 1.0)
    w3 = trial.suggest_uniform("w3", 0, 1.0)
    ws = np.array([w1, w2, w3])
    ws /= ws.sum()  # normalize parameters for the summation to be 1.
    nll = 0.0
    for i in range(len(data)):
        nll += -np.log(
            sum(
                [
                    ws[j] * np.exp(-np.sum((data[i] - gt_mean[j]) ** 2) / 2)
                    for j in range(len(gt_mean))
                ]
            )
        )
    return nll

ここで重要なのは、最適化対象となる混合比 ws をその和で割っている行です。これにより、初めは独立に一様分布からサンプリングされた3つの変数の和を 1 にした状態で negative log likelihood の計算に取り込むことができます。あとは通常どおり study オブジェクトをつくり、最適化計算を実行するだけです。簡単ですね!

結果

n_trials=100 として最終的に得た結果は

Best params (raw): [0.30270369 0.46664845 0.7754236 ]
Best params (normalized): [0.19595316 0.30208168 0.50196516]

となり、真の値 0.2, 0.3, 0.5 にそこそこ近いものが得られています。一方、上のコードで重みを正規化しない場合にどうなるかというと、

Best params (raw): [0.97299614 0.95032109 0.99884739]
Best params (normalized): [0.33297102 0.32521134 0.34181763]

のように全く的はずれな結果となりました(混合ガウス分布の対数尤度を計算したことになっていないので当然)。

混合比を推論する場合は普通 EM アルゴリズムなどを使い高精度に実行すると思いますが、今回はあくまでも Optuna の使い方の検証ということでご容赦ください。ディリクレ分布のように多変量かつ要素の和が 1 になるよう保証されたパラメータを探索するオプションがほしいですね。

使ったコードはこちらに置いておきました。
blog/estimate_2d_gmm_with_optuna.py at master · dlnp2/blog · GitHub