The Remarkable k-means++

October 1, 2012

(This article was originally published at Normal Deviate, and syndicated at StatsBlogs.)

1. The Problem

One of the most popular algorithms for clustering is k-means. We start with {n} data vectors {Y_1,\ldots, Y_n\in \mathbb{R}^d}. We choose {k} vectors — cluster centers — {c_1,\ldots, c_k \in \mathbb{R}^d} to minimize the error

\displaystyle  R(c_1,\ldots,c_k) = \frac{1}{n}\sum_{i=1}^n \min_{1\leq j \leq k} ||Y_i - c_j||^2.
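To make the objective concrete, here is a minimal Python sketch of {R}. It assumes the data and centers are plain tuples of floats, and the helper names {sq_dist} and {kmeans_risk} are hypothetical, not taken from any library.

```python
from typing import Sequence

Point = Sequence[float]


def sq_dist(a: Point, b: Point) -> float:
    """Squared Euclidean distance ||a - b||^2."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans_risk(data: Sequence[Point], centers: Sequence[Point]) -> float:
    """R(c_1,...,c_k): mean over the data of the squared distance to the nearest center."""
    return sum(min(sq_dist(y, c) for c in centers) for y in data) / len(data)


# Toy example: four points in R^2 forming two vertical pairs.
Y = [(0.0, 0.0), (0.0, 1.0), (4.0, 0.0), (4.0, 1.0)]
print(kmeans_risk(Y, [(0.0, 0.5), (4.0, 0.5)]))  # 0.25
```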

Unfortunately, finding {c_1,\ldots, c_k} to minimize {R(c_1,\ldots,c_k)} is NP-hard. The usual iterative method, Lloyd's algorithm, is easy to implement, but it is only guaranteed to reach a local minimum, which can be far from the global one. So finding

\displaystyle  \min_{c_1,\ldots, c_k}R(c_1,\ldots,c_k)

isn’t feasible.
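Lloyd's algorithm alternates two steps: assign each point to its nearest center, then move each center to the mean of the points assigned to it. Here is a minimal Python sketch, assuming points are tuples of floats. The name {lloyd}, the {max_iter} cap and the rule that an empty cluster keeps its old center are illustrative choices, not part of a standard implementation. The toy example shows the algorithm stopping at a poor local minimum.

```python
def sq_dist(a, b):
    """Squared Euclidean distance ||a - b||^2."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def lloyd(data, centers, max_iter=100):
    """Lloyd's iterations from the given starting centers; stops when nothing moves."""
    centers = [tuple(c) for c in centers]
    for _ in range(max_iter):
        # Assignment step: each point joins the cluster of its nearest center.
        clusters = [[] for _ in centers]
        for y in data:
            nearest = min(range(len(centers)), key=lambda i: sq_dist(y, centers[i]))
            clusters[nearest].append(y)
        # Update step: move each center to the mean of its cluster (empty cluster: keep it).
        new_centers = [
            tuple(sum(coord) / len(pts) for coord in zip(*pts)) if pts else c
            for pts, c in zip(clusters, centers)
        ]
        if new_centers == centers:  # converged to a fixed point
            break
        centers = new_centers
    return centers


Y = [(0.0, 0.0), (0.0, 1.0), (4.0, 0.0), (4.0, 1.0)]
stuck = lloyd(Y, [(2.0, 0.0), (2.0, 1.0)])
print(stuck)  # [(2.0, 0.0), (2.0, 1.0)]
```

With those starting centers every point is already assigned to the center nearest to it, so nothing moves and the algorithm stops with {R = 4}, sixteen times the optimal value of 0.25 attained by the centers (0, 0.5) and (4, 0.5).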

To deal with this, many people choose random starting values, run the {k}-means clustering algorithm, and then rinse, lather and repeat. In general, this may work poorly, and there is no theoretical guarantee of getting close to the minimum. Finding a practical method for approximately minimizing {R} is thus an important problem.
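The restart heuristic can be sketched as follows. This assumes each run starts from {k} distinct data points drawn uniformly at random and that the run with the smallest {R} is kept; {random_restarts}, {n_starts} and the other helper names are hypothetical. Nothing in this loop bounds how far the best run is from the optimum.

```python
import random


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def risk(data, centers):
    """R: mean squared distance from each point to its nearest center."""
    return sum(min(sq_dist(y, c) for c in centers) for y in data) / len(data)


def lloyd(data, centers, max_iter=100):
    """Lloyd's iterations; an empty cluster keeps its previous center."""
    centers = [tuple(c) for c in centers]
    for _ in range(max_iter):
        clusters = [[] for _ in centers]
        for y in data:
            nearest = min(range(len(centers)), key=lambda i: sq_dist(y, centers[i]))
            clusters[nearest].append(y)
        new = [tuple(sum(v) / len(p) for v in zip(*p)) if p else c
               for p, c in zip(clusters, centers)]
        if new == centers:
            break
        centers = new
    return centers


def random_restarts(data, k, n_starts=10, seed=0):
    """Run Lloyd from n_starts random sets of k data points; keep the lowest-R result."""
    rng = random.Random(seed)
    best, best_risk = None, float("inf")
    for _ in range(n_starts):
        centers = lloyd(data, rng.sample(list(data), k))
        r = risk(data, centers)
        if r < best_risk:
            best, best_risk = centers, r
    return best, best_risk


Y = [(0.0, 0.0), (0.0, 1.0), (4.0, 0.0), (4.0, 1.0)]
centers, r = random_restarts(Y, k=2, n_starts=20)
```

On this toy data a single run ends at either {R = 0.25} or {R = 4}, depending only on which two points it starts from; more restarts just make the bad outcome less likely, without any guarantee.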

2. The Solution

David Arthur and Sergei Vassilvitskii came up with a wonderful solution in 2007 known as k-means++.

The algorithm is simple and comes with a precise theoretical guarantee.

The first step is to choose a data point uniformly at random. Call this point {s_1}. Next, compute the squared distances

\displaystyle  D_i^2 = ||Y_i - s_1||^2.
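The first step in code, as a small sketch assuming points are tuples of floats and using Python's {random} module for the uniform draw; {first_seed} is a hypothetical helper name.

```python
import random


def sq_dist(a, b):
    """Squared Euclidean distance ||a - b||^2."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def first_seed(data, rng):
    """Choose s_1 uniformly at random and return it with D_i^2 = ||Y_i - s_1||^2."""
    s1 = rng.choice(data)
    d2 = [sq_dist(y, s1) for y in data]
    return s1, d2


Y = [(0.0, 0.0), (0.0, 1.0), (4.0, 0.0), (4.0, 1.0)]
s1, d2 = first_seed(Y, random.Random(1))
# Whichever corner is picked, D^2 contains 0 (s_1 itself), 1, 16 and 17.
```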

Now choose a second point {s_2} from the data, where the probability of choosing {Y_i} is {D_i^2/\sum_j D_j^2}, so points far from {s_1} are more likely to be chosen. Then recompute the distances as

\displaystyle  D_i^2 = \min\Bigl\{ ||Y_i - s_1||^2, ||Y_i - s_2||^2 \Bigr\}.
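Drawing {s_2} with probability {D_i^2/\sum_j D_j^2} is a weighted draw, and the distance update only needs the newest seed. A minimal sketch on four toy points, assuming Python's {random.choices} (which accepts relative weights) for the draw; {d2_probabilities}, {next_seed} and {update_d2} are hypothetical names. Since {D^2 = 0} at {s_1}, it gets probability zero and cannot be drawn again.

```python
import random
from fractions import Fraction


def sq_dist(a, b):
    """Squared Euclidean distance ||a - b||^2."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def d2_probabilities(d2):
    """Exact selection probabilities D_i^2 / sum_j D_j^2."""
    total = Fraction(sum(d2))
    return [Fraction(d) / total for d in d2]


def next_seed(data, d2, rng):
    """Draw one data point with probability proportional to D_i^2."""
    return rng.choices(data, weights=d2, k=1)[0]


def update_d2(data, d2, new_seed):
    """D_i^2 <- min(D_i^2, ||Y_i - s_new||^2): squared distance to the nearest seed so far."""
    return [min(d, sq_dist(y, new_seed)) for y, d in zip(data, d2)]


Y = [(0.0, 0.0), (0.0, 1.0), (4.0, 0.0), (4.0, 1.0)]
s1 = (0.0, 0.0)
d2 = [sq_dist(y, s1) for y in Y]  # [0.0, 1.0, 16.0, 17.0]
probs = d2_probabilities(d2)      # 0, 1/34, 8/17, 1/2
s2 = next_seed(Y, d2, random.Random(0))
d2_new = update_d2(Y, d2, s2)
```

Updating {D_i^2} against only the newest seed gives the same result as taking the minimum over all seeds chosen so far, but costs one distance computation per point per step.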

Then choose a third point {s_3} from the data, again choosing {Y_i} with probability {D_i^2/\sum_j D_j^2}. We continue until we have {k} points {s_1,\ldots,s_k}, each time updating {D_i^2} to be the squared distance from {Y_i} to the nearest point chosen so far. Finally, we run {k}-means clustering using {s_1,\ldots,s_k} as starting values. Call the resulting centers {\hat c_1,\ldots, \hat c_k}.
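Putting the steps together, here is a minimal, unoptimized sketch of k-means++ seeding followed by Lloyd's iterations, in plain Python with points as tuples of floats. The function names, the {max_iter} cap, and stopping the seeding early if every point already coincides with a seed (possible only when the data have fewer than {k} distinct points) are assumptions of this sketch, not details from the paper.

```python
import random


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeanspp_seeds(data, k, rng):
    """s_1 uniform; each later seed drawn with probability D_i^2 / sum_j D_j^2."""
    seeds = [rng.choice(data)]
    d2 = [sq_dist(y, seeds[0]) for y in data]
    while len(seeds) < k and sum(d2) > 0:
        s = rng.choices(data, weights=d2, k=1)[0]
        seeds.append(s)
        # D_i^2 becomes the squared distance to the nearest seed chosen so far.
        d2 = [min(d, sq_dist(y, s)) for y, d in zip(data, d2)]
    return seeds


def lloyd(data, centers, max_iter=100):
    """Standard k-means (Lloyd) iterations from the given starting centers."""
    centers = [tuple(c) for c in centers]
    for _ in range(max_iter):
        clusters = [[] for _ in centers]
        for y in data:
            nearest = min(range(len(centers)), key=lambda i: sq_dist(y, centers[i]))
            clusters[nearest].append(y)
        new = [tuple(sum(v) / len(p) for v in zip(*p)) if p else c
               for p, c in zip(clusters, centers)]
        if new == centers:
            break
        centers = new
    return centers


def kmeans_pp(data, k, seed=None):
    """Seed with k-means++, then run Lloyd's iterations; returns the final centers."""
    rng = random.Random(seed)
    return lloyd(data, kmeanspp_seeds(data, k, rng))


Y = [(0.0, 0.0), (0.0, 1.0), (100.0, 0.0), (100.0, 1.0)]
print(kmeans_pp(Y, 2, seed=0))
```

On these two well-separated pairs, the second seed lands in the same pair as the first with probability only 1/20002, so almost every run recovers the centers (0, 0.5) and (100, 0.5).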

Arthur and Vassilvitskii prove that

\displaystyle  \mathbb{E}[R(\hat c_1,\ldots,\hat c_k)] \leq 8 (\log k +2) \min_{c_1,\ldots, c_k}R(c_1,\ldots,c_k).
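To get a feel for the constant, here is the factor {8(\log k + 2)} evaluated for a few values of {k}, reading {\log} as the natural logarithm, as in Arthur and Vassilvitskii's paper; {bound_factor} is a hypothetical name.

```python
import math


def bound_factor(k: int) -> float:
    """The factor 8 (ln k + 2) multiplying min R in the k-means++ guarantee."""
    return 8 * (math.log(k) + 2)


for k in (2, 10, 100, 1000):
    print(k, round(bound_factor(k), 2))
# 2 21.55
# 10 34.42
# 100 52.84
# 1000 71.26
```

The factor grows only logarithmically: going from 10 to 1000 clusters roughly doubles it.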

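The expected value is over the randomness in the algorithm (the random choice of seeds); the data are held fixed, and {\log} is the natural logarithm. The bound is really a statement about the seeds {s_1,\ldots,s_k}; since the {k}-means iterations never increase {R}, it carries over to the final centers.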
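Because the expectation is only over the algorithm's random choices, it can be estimated for a fixed data set by averaging {R} over many independent runs. A minimal sketch, assuming a toy data set (the four corners of a 4-by-1 rectangle) whose optimal value {\min R = 0.25} is easy to check by hand; all function names and the trial count are illustrative.

```python
import math
import random


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def risk(data, centers):
    return sum(min(sq_dist(y, c) for c in centers) for y in data) / len(data)


def kmeanspp_seeds(data, k, rng):
    seeds = [rng.choice(data)]
    d2 = [sq_dist(y, seeds[0]) for y in data]
    while len(seeds) < k and sum(d2) > 0:
        s = rng.choices(data, weights=d2, k=1)[0]
        seeds.append(s)
        d2 = [min(d, sq_dist(y, s)) for y, d in zip(data, d2)]
    return seeds


def lloyd(data, centers, max_iter=100):
    centers = [tuple(c) for c in centers]
    for _ in range(max_iter):
        clusters = [[] for _ in centers]
        for y in data:
            clusters[min(range(len(centers)), key=lambda i: sq_dist(y, centers[i]))].append(y)
        new = [tuple(sum(v) / len(p) for v in zip(*p)) if p else c
               for p, c in zip(clusters, centers)]
        if new == centers:
            break
        centers = new
    return centers


def estimate_expected_risk(data, k, trials=4000, seed=0):
    """Monte Carlo estimate of E[R(c_hat)] over the algorithm's randomness (data fixed)."""
    rng = random.Random(seed)
    total = sum(risk(data, lloyd(data, kmeanspp_seeds(data, k, rng))) for _ in range(trials))
    return total / trials


Y = [(0.0, 0.0), (0.0, 1.0), (4.0, 0.0), (4.0, 1.0)]
opt = 0.25  # optimal R here: centers (0, 0.5) and (4, 0.5)
est = estimate_expected_risk(Y, k=2)
print(est, 8 * (math.log(2) + 2) * opt)  # estimate vs. the theoretical bound
```

Here only one unlucky draw (picking the vertical neighbor of {s_1} as {s_2}, probability 1/34) leaves Lloyd's algorithm stuck at {R = 4}; every other draw reaches the optimum. So the exact expectation is {(1/34)\cdot 4 + (33/34)\cdot 0.25 \approx 0.36}, far below the bound of about 5.39.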

There are various improvements to the algorithm, both in terms of computation and in terms of getting a sharper performance bound.

This is quite remarkable. One simple fix, and an intractable problem has become tractable. And the method comes armed with a theorem.

3. Questions

  1. Is there an R implementation? It is easy enough to code the algorithm, but it really should be part of the basic k-means function in R.
  2. Is there a version for mixture models? If not, it seems like a paper waiting to be written.
  3. Are there other intractable statistical problems that can be solved using simple randomized algorithms with provable guarantees? (MCMC doesn’t count because there is no finite-sample guarantee.)

4. Reference

Arthur, D. and Vassilvitskii, S. (2007). k-means++: The advantages of careful seeding. In Proceedings of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms (SODA '07), pp. 1027–1035.



