
How to extract unsupervised clusters from a Dirichlet Process in PyMC3?


Using a couple of new-ish additions to pymc3 will help make this clear. I think I updated the Dirichlet Process example after they were added, but it seems to have been reverted to the old version during a documentation cleanup; I will fix that soon.

One of the difficulties is that the data you have generated is much more dispersed than the priors on the component means can accommodate; if you standardize your data, the samples should mix much more quickly.

The second is that pymc3 now supports mixture distributions in which the component indicator variables have been marginalized out. These marginal mixture distributions help accelerate mixing and allow you to use NUTS (initialized with ADVI).
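To make the marginalization concrete: instead of sampling a discrete indicator for each observation, the likelihood sums the component densities weighted by w, so only continuous parameters remain and NUTS applies. A minimal NumPy/SciPy sketch of this marginal log-likelihood (my own illustration, not pymc3 internals):

    import numpy as np
    from scipy.special import logsumexp
    from scipy.stats import norm

    def marginal_mixture_logp(x, w, mu, sigma):
        # log p(x_i) = log sum_k w_k N(x_i | mu_k, sigma_k); the discrete
        # indicators never appear, so the posterior is fully continuous
        logp_k = np.log(w) + norm.logpdf(x[:, None], mu, sigma)
        return logsumexp(logp_k, axis=1).sum()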

Finally, when a truncated version of an infinite model runs into computational problems, it often helps to increase the number of potential components. I have found that K = 30 works better for this model than K = 15.
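One sanity check on the truncation level is that it should leave negligible stick mass beyond the first K components. A quick prior simulation (my own sketch, mirroring the stick-breaking construction in the model below, with an illustrative alpha = 3) shows the leftover mass shrinking as K grows:

    import numpy as np

    def stick_breaking_weights(alpha, K, rng):
        # Beta(1, alpha) stick fractions, converted to mixture weights
        beta = rng.beta(1., alpha, size=K)
        return beta * np.concatenate([[1.], np.cumprod(1. - beta)[:-1]])

    rng = np.random.default_rng(0)
    for K in (15, 30):
        w = stick_breaking_weights(3., K, rng)
        print(K, 1. - w.sum())  # leftover mass beyond the truncation

After sampling, the analogous posterior check is that 1 - trace['w'].sum(axis=1) stays near zero.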

The following code implements these changes and shows how the "active" component means can be extracted.

    from matplotlib import pyplot as plt
    import numpy as np
    import pymc3 as pm
    import seaborn as sns
    from theano import tensor as T

    blue = sns.color_palette()[0]

    np.random.seed(462233)  # from random.org

    # simulate a three-component mixture and standardize it
    N = 150

    CENTROIDS = np.array([0, 10, 50])
    WEIGHTS = np.array([0.4, 0.4, 0.2])

    x = np.random.normal(CENTROIDS[np.random.choice(3, size=N, p=WEIGHTS)], size=N)
    x_std = (x - x.mean()) / x.std()

    fig, ax = plt.subplots(figsize=(8, 6))

    ax.hist(x_std, bins=30);

[Figure: histogram of the standardized data]

    K = 30

    with pm.Model() as model:
        alpha = pm.Gamma('alpha', 1., 1.)
        beta = pm.Beta('beta', 1., alpha, shape=K)
        # stick-breaking construction of the truncated Dirichlet process weights
        w = pm.Deterministic('w', beta * T.concatenate([[1], T.extra_ops.cumprod(1 - beta)[:-1]]))

        tau = pm.Gamma('tau', 1., 1., shape=K)
        lambda_ = pm.Uniform('lambda', 0, 5, shape=K)
        mu = pm.Normal('mu', 0, tau=lambda_ * tau, shape=K)
        # marginal mixture likelihood; no per-observation indicator variables
        obs = pm.NormalMixture('obs', w, mu, tau=lambda_ * tau,
                               observed=x_std)

    with model:
        trace = pm.sample(2000, n_init=100000)

    fig, ax = plt.subplots(figsize=(8, 6))

    ax.bar(np.arange(K) - 0.4, trace['w'].mean(axis=0));

We see that three components appear to be used, and that their weights are reasonably close to the true values.

[Figure: posterior expected mixture weights]
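To pull out the "active" components programmatically, one option (continuing from the trace above; the 0.05 cutoff is a heuristic of mine, not part of the model) is to threshold the posterior mean weights:

    w_mean = trace['w'].mean(axis=0)
    active = w_mean > 0.05  # hypothetical cutoff; adjust to taste

    print(active.sum())                      # number of components in use
    print(trace['mu'].mean(axis=0)[active])  # their posterior mean locations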

Finally, we see that the posterior expected means of these three components match the true (standardized) means fairly well.

    trace['mu'].mean(axis=0)[:3]

    array([-0.73763891, -0.17284594,  2.10423978])

    (CENTROIDS - x.mean()) / x.std()

    array([-0.73017789, -0.16765707,  2.0824262 ])
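Since the question asks for cluster assignments, here is one way (my own plug-in sketch, continuing the session above; averaging responsibilities over posterior draws would be more rigorous) to hard-assign each observation to its most probable component and map the means back to the original scale:

    from scipy.stats import norm

    w_hat = trace['w'].mean(axis=0)
    mu_hat = trace['mu'].mean(axis=0)
    sd_hat = 1. / np.sqrt((trace['lambda'] * trace['tau']).mean(axis=0))

    # plug-in responsibilities; argmax gives a hard cluster label per point
    log_resp = np.log(w_hat) + norm.logpdf(x_std[:, None], mu_hat, sd_hat)
    labels = log_resp.argmax(axis=1)

    # component means back on the original scale of x
    print(mu_hat[:3] * x.std() + x.mean())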