2
雷鋒網按:本文作者賈志剛,原文載于作者個人博客,雷鋒網已獲授權。
高斯混合模型(GMM)在圖像分割、對象識別、視頻分析等方面均有應用,對于任意給定的數據樣本集合,根據其分布概率, 可以計算每個樣本數據向量的概率分布,從而根據概率分布對其進行分類,但是這些概率分布是混合在一起的,要從中分離出單個樣本的概率分布就實現了樣本數據聚類,而概率分布描述我們可以使用高斯函數實現,這個就是高斯混合模型-GMM。

這種方法也稱為D-EM即基于距離的期望最大化。
1. 初始化變量定義-指定的聚類數目K與數據維度D
2. 初始化均值、協方差、先驗概率分布
3. 迭代E-M步驟
- E步計算期望
- M步更新均值、協方差、先驗概率分布
- 檢測是否達到停止條件(最大迭代次數與最小誤差滿足),達到則退出迭代,否則繼續E-M步驟
4. 打印最終分類結果
package com.gloomyfish.image.gmm;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Gaussian Mixture Model (GMM) clustering trained with the
 * Expectation-Maximization (EM) algorithm.
 *
 * <p>Samples are passed as one flat array of length sampleCount * dimNum
 * (for image pixels dimNum would be 3, one slot per RGB channel). Each
 * mixture component is a diagonal-covariance Gaussian; EM alternates between
 * computing posterior responsibilities (E-step) and re-estimating the
 * weights/means/variances (M-step) until the log-likelihood stabilizes or
 * the iteration cap is reached.</p>
 *
 * @author gloomy fish
 */
public class GMMProcessor {

    /** Variance floor that keeps the Gaussian PDF from dividing by ~zero. */
    public final static double MIN_VAR = 1E-10;

    /** Demo data: 1-D samples forming two rough clusters (~10 and ~85). */
    public static double[] samples = new double[]{10, 9, 4, 23, 13, 16, 5, 90, 100, 80, 55, 67, 8, 93, 47, 86, 3};

    private int dimNum;         // dimensionality of one sample vector
    private int mixNum;         // number of Gaussian components (clusters)
    private double[] weights;   // prior (mixing) weight of each component
    private double[][] m_means; // component means, [mixNum][dimNum]
    private double[][] m_vars;  // diagonal variances, [mixNum][dimNum]
    private double[] m_minVars; // per-dimension minimum allowed variance

    /**
     * @param m_dimNum dimensionality of each sample vector (e.g. 3 for an RGB pixel)
     * @param m_mixNum number of Gaussian components, i.e. the cluster count
     */
    public GMMProcessor(int m_dimNum, int m_mixNum) {
        dimNum = m_dimNum;
        mixNum = m_mixNum;
        weights = new double[mixNum];
        m_means = new double[mixNum][dimNum];
        m_vars = new double[mixNum][dimNum];
        m_minVars = new double[dimNum];
    }

    /**
     * Runs EM clustering on the given flat sample array, then prints the final
     * component parameters and the hard cluster index of every sample.
     *
     * @param data flattened samples; length must be a multiple of dimNum
     */
    public void process(double[] data) {
        int m_maxIterNum = 100; // hard cap on EM iterations
        double err = 0.001;     // relative log-likelihood convergence tolerance
        boolean loop = true;
        int iterNum = 0;        // FIX: iteration counter is an int, not a double
        double lastL = 0;       // log-likelihood of the previous iteration
        double currL = 0;       // log-likelihood of the current iteration
        int unchanged = 0;      // consecutive iterations below the tolerance
        initParameters(data);
        // FIX: the sample count is data.length / dimNum; the original used
        // data.length, which over-runs the array whenever dimNum > 1.
        int size = data.length / dimNum;
        double[] x = new double[dimNum];
        double[][] next_means = new double[mixNum][dimNum];
        double[] next_weights = new double[mixNum];
        double[][] next_vars = new double[mixNum][dimNum];
        List<DataNode> cList = new ArrayList<DataNode>();
        while (loop) {
            // Reset the accumulators for this EM iteration.
            Arrays.fill(next_weights, 0);
            cList.clear();
            for (int i = 0; i < mixNum; i++) {
                Arrays.fill(next_means[i], 0);
                Arrays.fill(next_vars[i], 0);
            }
            lastL = currL;
            currL = 0;
            // E-step: accumulate posterior-weighted statistics over all samples.
            for (int k = 0; k < size; k++) {
                for (int j = 0; j < dimNum; j++) {
                    x[j] = data[k * dimNum + j];
                }
                double p = getProbability(x); // total mixture density at x
                // FIX: clone x — the original stored the shared scratch array,
                // so every DataNode ended up holding the last sample's values.
                DataNode dn = new DataNode(x.clone());
                dn.index = k;
                cList.add(dn);
                double maxp = 0;
                for (int j = 0; j < mixNum; j++) {
                    // Posterior probability that component j generated x.
                    double pj = getProbability(x, j) * weights[j] / p;
                    if (maxp < pj) {
                        maxp = pj;
                        dn.cindex = j; // hard label = most responsible component
                    }
                    next_weights[j] += pj;
                    for (int d = 0; d < dimNum; d++) {
                        next_means[j][d] += pj * x[d];
                        next_vars[j][d] += pj * x[d] * x[d];
                    }
                }
                // Guard against log(0) for samples with vanishing density.
                currL += (p > 1E-20) ? Math.log10(p) : -20;
            }
            currL /= size;
            // M-step: re-estimate weights, means and variances.
            for (int j = 0; j < mixNum; j++) {
                weights[j] = next_weights[j] / size;
                if (weights[j] > 0) {
                    for (int d = 0; d < dimNum; d++) {
                        m_means[j][d] = next_means[j][d] / next_weights[j];
                        // E[x^2] - E[x]^2, floored so the PDF stays finite.
                        m_vars[j][d] = next_vars[j][d] / next_weights[j] - m_means[j][d] * m_means[j][d];
                        if (m_vars[j][d] < m_minVars[d]) {
                            m_vars[j][d] = m_minVars[d];
                        }
                    }
                }
            }
            // Termination: likelihood converged 3 times in a row, or cap hit.
            iterNum++;
            if (Math.abs(currL - lastL) < err * Math.abs(lastL)) {
                unchanged++;
            }
            if (iterNum >= m_maxIterNum || unchanged >= 3) {
                loop = false;
            }
        }
        // Print the final component parameters.
        System.out.println("=================最終結果=================");
        for (int i = 0; i < mixNum; i++) {
            for (int k = 0; k < dimNum; k++) {
                System.out.println("[" + i + "]: ");
                System.out.println("means : " + m_means[i][k]);
                System.out.println("var : " + m_vars[i][k]);
                System.out.println();
            }
        }
        // Print the hard cluster assignment of every sample.
        // FIX: index the first dimension of sample i (data[i] only worked
        // for dimNum == 1).
        for (int i = 0; i < size; i++) {
            System.out.println("data[" + i + "]=" + data[i * dimNum] + " cindex : " + cList.get(i).cindex);
        }
    }

    /**
     * Initializes weights, means and variances from the data: means are drawn
     * from random samples, every sample is assigned to its nearest mean, and
     * the per-cluster statistics of that assignment seed the Gaussians.
     *
     * @param data flattened samples; length must be a multiple of dimNum
     */
    private void initParameters(double[] data) {
        // FIX: sample count (the original used data.length, wrong for dimNum > 1).
        int size = data.length / dimNum;
        // Seed each component mean with one randomly picked sample vector.
        // FIX: pick a single sample index per component and copy all of its
        // dimensions, instead of mixing dimensions from different samples.
        for (int i = 0; i < mixNum; i++) {
            int idx = (int) (Math.random() * size);
            for (int d = 0; d < dimNum; d++) {
                m_means[i][d] = data[idx * dimNum + d];
            }
        }
        // Hard-assign every sample to its NEAREST mean (L1 distance).
        // FIX: the original tracked the maximum distance (v > max), which
        // assigned each sample to the FARTHEST mean.
        int[] types = new int[size];
        for (int k = 0; k < size; k++) {
            double minDist = Double.MAX_VALUE;
            for (int i = 0; i < mixNum; i++) {
                double v = 0;
                for (int j = 0; j < dimNum; j++) {
                    v += Math.abs(data[k * dimNum + j] - m_means[i][j]);
                }
                if (v < minDist) {
                    minDist = v;
                    types[k] = i;
                }
            }
        }
        double[] counts = new double[mixNum];
        for (int i = 0; i < types.length; i++) {
            counts[types[i]]++;
        }
        // Prior weight of a component = fraction of samples assigned to it.
        for (int i = 0; i < mixNum; i++) {
            weights[i] = counts[i] / size;
        }
        // Per-cluster variances around the seeded means, plus overall stats.
        // FIX: the original indexed clusters through a never-filled Label[]
        // array (always 0), piling all variance onto cluster 0, and it
        // re-incremented `counts` without resetting it (double counting).
        Arrays.fill(counts, 0);
        double[] overMeans = new double[dimNum];
        double[] x = new double[dimNum];
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < dimNum; j++) {
                x[j] = data[i * dimNum + j];
            }
            int label = types[i];
            counts[label]++;
            for (int d = 0; d < dimNum; d++) {
                m_vars[label][d] += (x[d] - m_means[label][d]) * (x[d] - m_means[label][d]);
            }
            // Accumulate the overall mean and E[x^2] for the variance floor.
            for (int d = 0; d < dimNum; d++) {
                overMeans[d] += x[d];
                m_minVars[d] += x[d] * x[d];
            }
        }
        // Overall variance * 0.01 (never below MIN_VAR) is the minimum variance.
        for (int d = 0; d < dimNum; d++) {
            overMeans[d] /= size;
            m_minVars[d] = Math.max(MIN_VAR, 0.01 * (m_minVars[d] / size - overMeans[d] * overMeans[d]));
        }
        // Finish each Gaussian: average the squared deviations, apply the floor.
        for (int i = 0; i < mixNum; i++) {
            if (weights[i] > 0) {
                for (int d = 0; d < dimNum; d++) {
                    m_vars[i][d] = m_vars[i][d] / counts[i];
                    // A minimum variance per dimension is required.
                    if (m_vars[i][d] < m_minVars[d]) {
                        m_vars[i][d] = m_minVars[d];
                    }
                }
            }
        }
        System.out.println("=================初始化=================");
        for (int i = 0; i < mixNum; i++) {
            for (int k = 0; k < dimNum; k++) {
                System.out.println("[" + i + "]: ");
                System.out.println("means : " + m_means[i][k]);
                System.out.println("var : " + m_vars[i][k]);
                System.out.println();
            }
        }
    }

    /**
     * Mixture density at the given point: the prior-weighted sum of all
     * component densities.
     *
     * @param sample sample vector of length dimNum
     * @return total probability density of the mixture at {@code sample}
     */
    public double getProbability(double[] sample) {
        double p = 0;
        for (int i = 0; i < mixNum; i++) {
            p += weights[i] * getProbability(sample, i);
        }
        return p;
    }

    /**
     * Density of a single diagonal-covariance Gaussian component at x
     * (product of 1-D Gaussian PDFs over the dimensions).
     *
     * @param x sample vector of length dimNum
     * @param j index of the Gaussian component
     * @return probability density of component j at x
     */
    public double getProbability(double[] x, int j) {
        double p = 1;
        for (int d = 0; d < dimNum; d++) {
            // FIX: use Math.PI instead of the truncated constant 3.14159.
            p *= 1 / Math.sqrt(2 * Math.PI * m_vars[j][d]);
            p *= Math.exp(-0.5 * (x[d] - m_means[j][d]) * (x[d] - m_means[j][d]) / m_vars[j][d]);
        }
        return p;
    }

    public static void main(String[] args) {
        // Cluster the 1-D demo samples into two Gaussian components.
        GMMProcessor filter = new GMMProcessor(1, 2);
        filter.process(samples);
    }
}
結構類DataNode
package com.gloomyfish.image.gmm;
/**
 * Lightweight holder for one sample during GMM clustering: the sample vector,
 * its position in the input data, and the cluster it is currently assigned to.
 */
public class DataNode {

    /** Assigned cluster index; -1 means "not yet assigned". */
    public int cindex;
    /** Position of this sample in the input data; -1 until set. */
    public int index;
    /** The sample vector itself. */
    public double[] value;

    /**
     * Wraps a sample vector; both indices start out as -1 (unassigned).
     *
     * @param v the sample vector to hold
     */
    public DataNode(double[] v) {
        this.value = v;
        this.cindex = -1;
        this.index = -1;
    }
}

這里初始中心均值的方法我是通過隨機數來實現,GMM算法運行結果跟初始化有很大關系,常見初始化中心點的方法是通過K-Means來計算出中心點。大家可以嘗試修改代碼基于K-Means初始化參數,我之所以選擇隨機參數初始,主要是為了省事!
雷鋒網相關閱讀:
25 行 Python 代碼實現人臉檢測——OpenCV 技術教程
手把手教你如何用 OpenCV + Python 實現人臉識別
深度學習之神經網絡特訓班
20年清華大學神經網絡授課導師鄧志東教授,帶你系統學習人工智能之神經網絡理論及應用!
課程鏈接:http://www.mooc.ai/course/65
加入AI慕課學院人工智能學習交流QQ群:624413030,與AI同行一起交流成長
雷峰網版權文章,未經授權禁止轉載。詳情見轉載須知。