I have recently been reading Li Hang's Statistical Learning Methods. For the chapter on the EM algorithm I collected some material, studied it, and ran a few Gaussian-mixture simulations. The content below is organized in three parts: 1) the principle of the EM algorithm, 2) the derivation for Gaussian mixtures, 3) code and results.
1. The EM Algorithm
An important tool in the derivation of EM is Jensen's inequality. It states: if $f$ is a convex function (i.e. $f''(x) \ge 0$), then $E[f(X)] \ge f(E[X])$, and equality holds if and only if $X$ is constant, i.e. $X = E[X]$ with probability 1.
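As a concrete numerical illustration of the concave case used below (my own check, not taken from the book): let $X$ take the values $1$ and $3$ with probability $1/2$ each. Then $E[\log X] = \tfrac{1}{2}(\log 1 + \log 3) \approx 0.55$, while $\log E[X] = \log 2 \approx 0.69$, so $E[\log X] \le \log E[X]$. For a concave function such as $\log$, Jensen's inequality holds with the direction reversed, and this reversed form is exactly what the EM derivation exploits.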
If a probabilistic model involves only observed variables, then from the observed values of $Y$ its parameters $\theta$ can be estimated directly by maximum likelihood or by Bayesian estimation. But if, besides the observed variables $Y$, the model also contains a latent variable $Z$ whose values cannot be observed, the EM algorithm is needed to estimate the parameters of the model over $Y$ and $Z$; EM can be regarded as maximum likelihood estimation for models containing latent variables.
Let the observed data be $y_1, \dots, y_N$, let $z$ be the latent variable, and let $Q_j(z)$ denote a distribution over the latent variable for sample $j$ (with $\sum_z Q_j(z) = 1$). The log-likelihood can then be written as

$\log L(\theta) = \sum_j \log p(y_j;\theta) = \sum_j \log \sum_z p(y_j, z; \theta) = \sum_j \log \sum_z Q_j(z)\,\frac{p(y_j, z; \theta)}{Q_j(z)}.$

The inner sum $\sum_z Q_j(z)\,\frac{p(y_j, z; \theta)}{Q_j(z)}$ can be viewed as the expectation $E_{z \sim Q_j}\!\left[\frac{p(y_j, z; \theta)}{Q_j(z)}\right]$. Applying Jensen's inequality, and noting that $\log$ is concave so the inequality is reversed, we obtain

$\log L(\theta) \ge \sum_j \sum_z Q_j(z) \log \frac{p(y_j, z; \theta)}{Q_j(z)}.$
To make this lower bound tight at the current parameters, equality must hold in Jensen's inequality, which requires the ratio inside the expectation to be constant with respect to $z$:

$\frac{p(y_j, z; \theta)}{Q_j(z)} = c.$

Since $\sum_z Q_j(z) = 1$, it follows that

$Q_j(z) = \frac{p(y_j, z; \theta)}{\sum_z p(y_j, z; \theta)} = p(z \mid y_j; \theta),$

i.e. $Q_j$ is the posterior distribution of the latent variable under the current parameters (the E-step). Once $Q_j$ is known and held fixed, $\theta$ is adjusted to maximize the lower bound (the M-step).
Because each M-step maximizes the lower bound, writing $B(\theta, \theta^{(i)})$ for the bound constructed at $\theta^{(i)}$ we have

$L(\theta^{(i+1)}) \ge B(\theta^{(i+1)}, \theta^{(i)}) \ge B(\theta^{(i)}, \theta^{(i)}) = L(\theta^{(i)}),$

so the log-likelihood increases monotonically over the iterations. Being monotone and bounded above, the sequence of likelihood values converges; in this sense the EM algorithm converges (in general to a stationary point rather than the global maximum).
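For reference, the derivation above can be restated compactly in the Q-function notation used in the book, with $Y$ the observed data and $Z$ the latent variable:

E-step: $Q(\theta, \theta^{(i)}) = E_Z\!\left[\log P(Y, Z \mid \theta) \,\middle|\, Y, \theta^{(i)}\right] = \sum_{Z} P(Z \mid Y, \theta^{(i)}) \log P(Y, Z \mid \theta)$

M-step: $\theta^{(i+1)} = \arg\max_{\theta} Q(\theta, \theta^{(i)})$

The posterior $P(Z \mid Y, \theta^{(i)})$ is the $Q_j(z)$ chosen above, and maximizing $Q(\theta, \theta^{(i)})$ is equivalent to maximizing the lower bound, since the $-\sum_z Q_j(z)\log Q_j(z)$ term does not depend on $\theta$.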
2. Gaussian Mixtures
The Gaussian mixture density is

$p(y \mid \theta) = \sum_{k=1}^{K} \alpha_k\, \phi(y \mid \theta_k), \qquad \alpha_k \ge 0,\ \sum_{k=1}^{K}\alpha_k = 1,$

where each component is a Gaussian

$\phi(y \mid \theta_k) = \frac{1}{\sqrt{2\pi}\,\sigma_k} \exp\!\left(-\frac{(y-\mu_k)^2}{2\sigma_k^2}\right), \qquad \theta_k = (\mu_k, \sigma_k^2).$

The latent variable is $\gamma_{jk} \in \{0, 1\}$, indicating whether the $j$-th observation was generated by the $k$-th Gaussian component; some references instead write a single label $z_j = k$, but the meaning is the same. Its posterior expectation (the responsibility) plays the role of the $Q_j(z)$ from the derivation in Part 1:

$\hat\gamma_{jk} = E[\gamma_{jk} \mid y, \theta] = \frac{\alpha_k\, \phi(y_j \mid \theta_k)}{\sum_{l=1}^{K} \alpha_l\, \phi(y_j \mid \theta_l)}.$

The objective function (the expected complete-data log-likelihood, i.e. the Q-function) is

$Q(\theta, \theta^{(i)}) = \sum_{k=1}^{K} \left\{ \sum_{j=1}^{N} \hat\gamma_{jk} \log \alpha_k + \sum_{j=1}^{N} \hat\gamma_{jk} \left[ \log\frac{1}{\sqrt{2\pi}} - \log\sigma_k - \frac{(y_j - \mu_k)^2}{2\sigma_k^2} \right] \right\}.$
Setting the derivatives of $Q$ to zero gives:

(1) $\dfrac{\partial Q}{\partial \mu_k} = \sum_{j=1}^{N} \hat\gamma_{jk}\,\dfrac{y_j - \mu_k}{\sigma_k^2} = 0$

For $\alpha_k$ the constraint $\sum_{k=1}^{K}\alpha_k = 1$ must be taken into account; by the method of Lagrange multipliers,

(2) $\dfrac{\partial}{\partial \alpha_k}\left[ Q + \lambda\left(\sum_{k=1}^{K}\alpha_k - 1\right) \right] = \sum_{j=1}^{N}\frac{\hat\gamma_{jk}}{\alpha_k} + \lambda = 0$

(3) $\dfrac{\partial Q}{\partial \sigma_k^2} = \sum_{j=1}^{N} \hat\gamma_{jk}\left[ \frac{(y_j-\mu_k)^2}{2\sigma_k^4} - \frac{1}{2\sigma_k^2} \right] = 0$

Solving (1), (2) and (3) yields the Gaussian mixture parameters.
The resulting parameter updates are:

$\hat\mu_k = \frac{\sum_{j=1}^{N} \hat\gamma_{jk}\, y_j}{\sum_{j=1}^{N} \hat\gamma_{jk}}, \qquad \hat\sigma_k^2 = \frac{\sum_{j=1}^{N} \hat\gamma_{jk}\,(y_j - \hat\mu_k)^2}{\sum_{j=1}^{N} \hat\gamma_{jk}}, \qquad \hat\alpha_k = \frac{\sum_{j=1}^{N} \hat\gamma_{jk}}{N}.$
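One small rearrangement is worth noting before the code: expanding the square in the variance update gives

$\hat\sigma_k^2 = \frac{\sum_j \hat\gamma_{jk} y_j^2 - 2\hat\mu_k \sum_j \hat\gamma_{jk} y_j + \hat\mu_k^2 \sum_j \hat\gamma_{jk}}{\sum_j \hat\gamma_{jk}},$

so a single pass over the pixels only needs to accumulate $\sum_j \hat\gamma_{jk}$, $\sum_j \hat\gamma_{jk}\, y_j$ and $\sum_j \hat\gamma_{jk}\, y_j^2$; these are exactly the sum_gamma, sum_gammay and sum_gammayy accumulators in the code below.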
3. Gaussian Mixture Experiment (OpenCV 2.4.9)
The following program fits a 4-component Gaussian mixture to the grey-level distribution of one colour channel of an image and plots the fitted curve over the channel's histogram.
// HelloOpenCV.cpp : Defines the entry point for the console application.
//
#include "stdafx.h"
#include <opencv2/opencv.hpp>
#include <iostream>
#include <math.h>

#define COUNT 4        // number of Gaussian components
#define HIST_SIZE 256  // number of grey levels / histogram bins
#define EPS 1e-32      // guard against division by zero

using namespace std;
using namespace cv;

// Evaluate the COUNT Gaussian densities at grey level `grey`.
// u[i] is the mean and delta[i] the variance of component i;
// 0.39894228 is 1/sqrt(2*pi).
void Gaussian_fun(double *u, double *delta, double *p, unsigned char grey)
{
    int i;
    for (i = 0; i < COUNT; i++)
    {
        p[i] = (0.39894228 * exp(-pow(grey - u[i], 2) / 2 / (delta[i] + EPS))) / sqrt(delta[i] + EPS);
    }
}
// Draw a histogram (or fitted curve) into hist_img as a polyline in the given colour.
// Bin counts are rescaled so the curve fits inside the 256x256 display image.
void Drawhist(CvHistogram *hist, IplImage *hist_img, CvScalar scalar)
{
    int i;
    double probability, probability_old;
    CvSize size = cvGetSize(hist_img);
    probability_old = cvGetReal1D(hist->bins, 0);
    probability_old = cvRound(HIST_SIZE * probability_old / size.height / size.width);
    for (i = 1; i < HIST_SIZE; i++)
    {
        probability = cvGetReal1D(hist->bins, i);
        probability = cvRound(HIST_SIZE * probability / size.height / size.width);
        cvLine(hist_img, cvPoint(i - 1, cvRound(1.5 * (128 - probability_old))),
               cvPoint(i, cvRound(1.5 * (128 - probability))), scalar);
        probability_old = probability;
    }
}
// L2 distance between two length-COUNT double vectors; used as the convergence test.
double Distance(void *new_mat, void *old_mat)
{
    CvMat mat1 = cvMat(COUNT, 1, CV_64FC1, new_mat);
    CvMat mat2 = cvMat(COUNT, 1, CV_64FC1, old_mat);
    return cvNorm(&mat1, &mat2, CV_L2, 0);
}
// Fit a COUNT-component Gaussian mixture to the grey-level distribution of img
// by EM, then draw the fitted density into hist_img.
void EM_GMM(IplImage *img, IplImage *hist_img)
{
    int i, j, k, iter = 0;
    unsigned char grey;
    // initial guesses: equal weights, moderate variances, means spread over [0, 255]
    double alpha[COUNT] = { 0.25, 0.25, 0.25, 0.25 }, delta[COUNT] = { 20, 20, 20, 20 }, u[COUNT] = { 50, 100, 150, 200 };
    double u_old[COUNT] = { 0 }, alpha_old[COUNT] = { 0 }, delta_old[COUNT] = { 0 };
    double p[COUNT] = { 0 };
    CvSize size = cvGetSize(img);
    double sum_p, sum_gamma[COUNT] = { 0 }, sum_gammay[COUNT] = { 0 }, sum_gammayy[COUNT] = { 0 }, gamma[COUNT] = { 0 };
    while (iter < 1000 && Distance(alpha, alpha_old) > 0.01 && Distance(u, u_old) > 0.01 && Distance(delta, delta_old) > 0.01)
    {
        memset(gamma, 0, sizeof(gamma));
        memset(sum_gamma, 0, sizeof(sum_gamma));
        memset(sum_gammay, 0, sizeof(sum_gammay));
        memset(sum_gammayy, 0, sizeof(sum_gammayy));
        // E-step: accumulate responsibilities and their first/second moments
        for (i = 0; i < size.height; i++)
        {
            for (j = 0; j < size.width; j++)
            {
                grey = (unsigned char)img->imageData[i * img->widthStep + j];
                Gaussian_fun(u, delta, p, grey);
                sum_p = EPS;
                for (k = 0; k < COUNT; k++)
                    sum_p += alpha[k] * p[k];
                for (k = 0; k < COUNT; k++)
                {
                    gamma[k] = alpha[k] * p[k] / sum_p;        // responsibility of component k for this pixel
                    sum_gamma[k] += gamma[k];                  // sum of gamma
                    sum_gammay[k] += gamma[k] * grey;          // sum of gamma * y
                    sum_gammayy[k] += gamma[k] * grey * grey;  // sum of gamma * y^2
                }
            }
        }
        // M-step: update weights, means and variances from the accumulated sums
        for (k = 0; k < COUNT; k++)
        {
            alpha_old[k] = alpha[k];
            u_old[k] = u[k];
            delta_old[k] = delta[k];
            alpha[k] = sum_gamma[k] / size.height / size.width;
            u[k] = sum_gammay[k] / (sum_gamma[k] + EPS);
            delta[k] = (sum_gammayy[k] - 2 * u[k] * sum_gammay[k] + u[k] * u[k] * sum_gamma[k]) / (sum_gamma[k] + EPS);
        }
        iter++;
    }
    // Evaluate the fitted mixture at every grey level and store it in a histogram,
    // scaled back to pixel counts so it can be drawn with Drawhist.
    int sizes = HIST_SIZE;
    float range[2] = { 0, 256 };
    float *ranges = range;
    CvHistogram *hist = cvCreateHist(1, &sizes, CV_HIST_ARRAY, &ranges, 1);
    for (i = 0; i < HIST_SIZE; i++)
    {
        Gaussian_fun(u, delta, p, (unsigned char)i);
        sum_p = 0;
        for (k = 0; k < COUNT; k++)
            sum_p += alpha[k] * p[k];
        cvSetReal1D(hist->bins, i, cvRound(sum_p * size.width * size.height));
    }
    Drawhist(hist, hist_img, cvScalar(255, 0, 0, 0)); // blue curve (BGR order)
    cvReleaseHist(&hist);
}
int main(int argc, const char* argv[])
{
    IplImage *img = cvLoadImage("..\\starry_night.jpg", 1);
    if (!img)
    {
        cout << "could not load image" << endl;
        return -1;
    }
    IplImage *imgRed = cvCreateImage(cvGetSize(img), 8, 1);
    IplImage *imgGreen = cvCreateImage(cvGetSize(img), 8, 1);
    IplImage *imgBlue = cvCreateImage(cvGetSize(img), 8, 1);
    // cvLoadImage returns BGR data, so split in B, G, R order
    cvSplit(img, imgBlue, imgGreen, imgRed, 0);
    cvNamedWindow("img", CV_WINDOW_AUTOSIZE);
    cvShowImage("img", img);
    cvWaitKey(0);
    // histogram of the red channel: 256 bins over [0, 256)
    int sizes = HIST_SIZE;
    float range[] = { 0, 256 };
    float *ranges[] = { range };
    CvHistogram *hist = cvCreateHist(1, &sizes, CV_HIST_ARRAY, ranges, 1);
    cvCalcHist(&imgRed, hist, 0, 0);
    IplImage *hist_img = cvCreateImage(cvSize(256, 256), IPL_DEPTH_8U, 3);
    cvZero(hist_img);                                 // black background for the plot
    Drawhist(hist, hist_img, cvScalar(0, 0, 255, 0)); // red: the actual histogram
    // fit the Gaussian mixture by EM and overlay the fitted curve in blue
    EM_GMM(imgRed, hist_img);
    cvNamedWindow("EM&GMM");
    cvShowImage("EM&GMM", hist_img);
    cvWaitKey(0);
    cvDestroyAllWindows();
    cvReleaseHist(&hist);
    cvReleaseImage(&hist_img);
    cvReleaseImage(&imgRed);
    cvReleaseImage(&imgGreen);
    cvReleaseImage(&imgBlue);
    cvReleaseImage(&img);
    return 0;
}
P.S. The red curve is the image's histogram; the blue curve is the fitted Gaussian mixture.
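For comparison, OpenCV 2.4.x also ships a built-in EM implementation (cv::EM in the ml module). The sketch below shows how a 4-component fit could be run on the red-channel pixels; it is only a sketch based on my reading of the 2.4.x documentation, and in particular the parameter names "weights", "means" and "covs" passed to Algorithm::get() are assumptions to verify against your installed headers.

// Sketch (untested): fit a 4-component GMM to the grey levels with cv::EM from
// OpenCV 2.4.x's ml module, assuming EM(nclusters, covMatType), train(), and
// parameter access via Algorithm::get("weights"/"means"/"covs").
#include <opencv2/opencv.hpp>
#include <iostream>
#include <vector>

void FitWithBuiltinEM(IplImage *channel)
{
    const int K = 4;                                            // number of components
    cv::Mat grey = cv::cvarrToMat(channel);                     // wrap the IplImage header
    cv::Mat flat = grey.clone().reshape(1, (int)grey.total());  // contiguous N x 1 copy
    cv::Mat samples;
    flat.convertTo(samples, CV_64FC1);                          // EM expects floating-point samples

    cv::EM em(K, cv::EM::COV_MAT_SPHERICAL);                    // scalar variance per component
    em.train(samples);

    cv::Mat weights = em.get<cv::Mat>("weights");               // mixing coefficients alpha_k (assumed name)
    cv::Mat means = em.get<cv::Mat>("means");                   // component means mu_k (assumed name)
    std::vector<cv::Mat> covs = em.get<std::vector<cv::Mat> >("covs"); // variances (assumed name)

    for (int k = 0; k < K; k++)
        std::cout << "k=" << k
                  << " alpha=" << weights.at<double>(k)
                  << " mu=" << means.at<double>(k)
                  << " var=" << covs[k].at<double>(0, 0) << std::endl;
}

Calling FitWithBuiltinEM(imgRed) after the EM_GMM call should give parameters close to those found by the hand-written EM above, possibly in a different component order.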
References
[1] 李航, 统计学习方法 (Li Hang, Statistical Learning Methods)
[2] Andrew Ng, Machine Learning lecture notes
[3] JerryLead, http://www.cnblogs.com/jerrylead/archive/2011/04/06/2006936.html