★データ解析備忘録★

ゆる〜い技術メモ

多項分布とディリクレ分布のまとめと可視化

多項分布とその共役事前分布について、可視化をしながら整理してみたいと思います。
どちらかというと、可視化をしてパラメーターで分布の形がどう変わるのかを見ることがメインです。

多項分布とは

二項分布の一般化と考えればよいです。
「コインを投げた時の表裏の分布」が二項分布なら、多項分布は「さいころを投げて出る目の分布」になりますね。

確率密度関数
{ \displaystyle
Multi(\boldsymbol{x} | n,p) = \frac{n!}{n_1! \cdots n_k!}p_1^{n_1} \cdots p_k^{n_k}
}
で表され、 p_1 + \cdots + p_k = 1になります。。
さっきの言い方を上の式っぽく一般化すると、多項分布とは
確率  P_iで事象 A_iが起こる (i=1,\cdots,k_i=1,\cdots,k)ような試行を n回行ったとき「どの事象が何回起こったか」を表す分布です。また、k=2のとき二項分布になります

その他、平均(期待値)は
 E[N_i] = np_i
分散は
 V[N_i] = np_i(1-p_i)
となります。

ディリクレ分布とは

自然言語処理ではトピックモデルの潜在ディリクレ過程(Latent Dirichlet Allocation; LDA)でおなじみですね。
ディリクレ分布は多項分布の共役事前分布として知られており、一言でいうと「さいころの目の出やすさ」の分布です。
多項分布は「出る目の分布」でディリクレ分布が「目の出やすさの分布」ということで、ベイズ統計では多項分布の事前分布としてディリクレ分布がよく使われます。
また、よく使われる理由のもう一つとして多項分布とディリクレ分布の積はディリクレ分布の形だということもあります。
つまり、事後確率分布が事前確率分布と同じ関数形になるということです。*1
(ちなみに二項分布の場合はベータ分布が同じ関係になります。なので、ディリクレ分布はベータ分布の一般化といえます。)

式で表すと

{ \displaystyle
Dir(\boldsymbol{p} | \boldsymbol{\alpha}) = \frac{\Gamma(\sum_{k=1}^K\alpha_k)}{\prod_{k=1}^K \Gamma(\alpha_k)}\prod_{k=1}^K p_k^{\alpha_k-1}
}
ただし、\Gamma(\cdot)はガンマ関数で自然数mに対して\Gamma(m+1)=m!

となります。
数式的な言葉の説明だと
「あるn個の事象についてi番目の事象が\alpha_i-1回発生した場合に、その事象の生起確率がx_iである確率」の分布ということになります。
多項分布とは変数にしているものが異なるので注意しましょう。

x_iの平均(期待値)は
{ \displaystyle
E[x_i] = \frac{\alpha_i}{\sum_{j=1}^n\alpha_j}
}
分散は
{ \displaystyle
V[x_i] = \frac{E[x_i]\left(1-E[x_i]\right)}{1+\sum_{j=1}^n\alpha_j}
}
です。このあたりの細かい導出等は省略します。*2

可視化してみる

4変量以上だと目に見える形にできないので3変量で可視化をしてみます。
MatlabPythonを使いますが環境は以下の通りです。

多項分布

http://jp.mathworks.com/help/stats/mnpdf.html
を参考にして*3Matlabで3Dの可視化をします。
まず、標本サイズを n-10とし、3つの結果が現れる確率は p_1 = \frac{1}{2}, p_2 =\frac{1}{3}, p_3=\frac{1}{6}とします。

p = [1/2 1/3 1/6];
n = 10;
x1 = 0:n;
x2 = 0:n;
[X1,X2] = meshgrid(x1,x2);
X3 = n-(X1+X2);

Y = mnpdf([X1(:),X2(:),X3(:)],repmat(p,(n+1)^2,1));

Y = reshape(Y,n+1,n+1);
bar3(Y)
h = gca;
h.XTickLabel = [0:n];
h.YTickLabel = [0:n];
xlabel('x_1')
ylabel('x_2')
zlabel('Probability Mass')
title('Trinomial Distribution')

f:id:songcunyouzai:20160303140446p:plain

ちなみに n=100のときは以下のようになりました。
f:id:songcunyouzai:20160303140932p:plain

ディリクレ分布

MATLAB Central - how to plot 3-dimension Dirichlet distribution
を参考にしてMatlabで3Dの可視化をし、
Visualizing Dirichlet Distributions with Matplotlib
を参考にして*4Pythonのmatplotlibで同じものを上から見た2Dに落とし込みます。

基本的なMatlabコードとPythonコードは以下のとおりです。

Matlab

alpha = [2 3 4];
x1 = linspace(0,1,101);
x2 = linspace(0,1,101);
[X1,X2] = ndgrid(x1,x2);
X3 = 1 - X1 - X2;
bad = (X1+X2 > 1); X1(bad) = NaN; X2(bad) = NaN; X3(bad) = NaN;

betaConst = exp(sum(gammaln(alpha))-gammaln(sum(alpha)));
F = (X1.^(alpha(1)-1) .* X2.^(alpha(2)-1) .* X3.^(alpha(3)-1)) / betaConst;

figure, surf(X1,X2,F,'EdgeColor','none');
xlabel('x1'); ylabel('x2'); zlabel('f(x1,x2,1-x1-x2)');
view(-160,40);


Python

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as tri
from functools import reduce

corners = np.array([[0, 0], [1, 0], [0.5, 0.75**0.5]])
triangle = tri.Triangulation(corners[:, 0], corners[:, 1])

refiner = tri.UniformTriRefiner(triangle)
trimesh = refiner.refine_triangulation(subdiv=4)

plt.figure(figsize=(8, 4))
for (i, mesh) in enumerate((triangle, trimesh)):
    plt.subplot(1, 2, i+ 1)
    plt.triplot(mesh)
    plt.axis('off')
    plt.axis('equal')

midpoints = [(corners[(i + 1) % 3] + corners[(i + 2) % 3]) / 2.0 \
             for i in range(3)]
def xy2bc(xy, tol=1.e-3):
    '''Converts 2D Cartesian coordinates to barycentric.'''
    s = [(corners[i] - midpoints[i]).dot(xy - midpoints[i]) / 0.75 \
         for i in range(3)]
    return np.clip(s, tol, 1.0 - tol)
    
class Dir(object):
    def __init__(self, alpha):
        from math import gamma
        from operator import mul
        self._alpha = np.array(alpha)
        self._coef = gamma(np.sum(self._alpha)) / \
                     reduce(mul, [gamma(a) for a in self._alpha])
    def pdf(self, x):
        '''Returns pdf value for `x`.'''
        from operator import mul
        return self._coef * reduce(mul, [xx ** (aa - 1)
                                         for (xx, aa)in zip(x, self._alpha)])

def draw_pdf(dist, nlevels=200, subdiv=8, **kwargs):
    import math

    refiner = tri.UniformTriRefiner(triangle)
    trimesh = refiner.refine_triangulation(subdiv=subdiv)
    pvals = [dist.pdf(xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)]

    plt.tricontourf(trimesh, pvals, nlevels, **kwargs)
    plt.axis('equal')
    plt.xlim(0, 1)
    plt.ylim(0, 0.75**0.5)
    plt.axis('off')

パラメーターの値を変えていきます。
Matlabではalpha = [1 1 1]のところを変え、Pythonではdraw_pdf(Dir([1, 1, 1]))のところを変えて入力することで画像を出力します。
MatlabPythonで色が微妙に違うのはご容赦ください。
 Dir(1, 1, 1)
f:id:songcunyouzai:20160303141329p:plain
f:id:songcunyouzai:20160229171915p:plain

 Dir(10, 10, 10)
f:id:songcunyouzai:20160303141614p:plain
f:id:songcunyouzai:20160229172033p:plain

 Dir(5, 5, 5)
f:id:songcunyouzai:20160303141734p:plain
f:id:songcunyouzai:20160229172110p:plain

 Dir(10, 1, 1)
f:id:songcunyouzai:20160303141910p:plain
f:id:songcunyouzai:20160229172341p:plain

 Dir(2, 3, 4)
f:id:songcunyouzai:20160301005652j:plain
f:id:songcunyouzai:20160229172249p:plain

 Dir(2, 5, 15)
f:id:songcunyouzai:20160303144025p:plain
f:id:songcunyouzai:20160303144300p:plain

本当はパラメーターを変えた時のアニメーションとかやりたかったんですが、僕の知識ではできなかったので、どなたかわかる方がいたらご享受願います。

*1:数式的証明は今回は省きます。

*2:興味ある方は数学書をどうぞ

*3:ここはコードをそのままなぞっています。

*4:上記記事はPython2系で書かれてますが、3系に移植する際、reduce関数は当ブログ4行目のようにしてやる必要があります。