EM算法及其推广(三)
这一部分我们看一个简单的示例。
数据
在这里我们模拟由两个正态分布(标准差相同且已知)混合而成的数据,并用EM算法估计这两个分布的均值。
产生训练数据的程序如下:
def ini_data(Sigma, Mu1, Mu2, k, N):
    """Generate the synthetic training set and initialise the EM state.

    Draws N samples from an equal-weight mixture of two Gaussians that share
    the standard deviation ``Sigma`` and have means ``Mu1`` and ``Mu2``, then
    (re)creates the module-level globals used by the other steps:

    * ``X`` -- array of shape (1, N) holding the generated samples.
    * ``Mu`` -- array of k random values in [0, 1), the initial mean guesses.
    * ``Expectations`` -- zero array of shape (N, k) filled in by the E-step.

    When the module-level flag ``isdebug`` is truthy, the generated data is
    printed.

    Args:
        Sigma: Standard deviation shared by both generating Gaussians.
        Mu1: Mean of the first generating Gaussian.
        Mu2: Mean of the second generating Gaussian.
        k: Number of mixture components to estimate (2 in this example).
        N: Number of samples to generate.
    """
    global X, Mu, Expectations
    X = np.zeros((1, N))
    # One initial guess per component, so Mu[j] is valid for every j < k.
    Mu = np.random.random(k)
    Expectations = np.zeros((N, k))
    # `range` instead of the Python 2-only `xrange`; the rest of the code
    # (print calls with several arguments) already targets Python 3.
    for i in range(N):
        # Fair coin flip decides which Gaussian produces sample i.
        component_mean = Mu1 if np.random.random() > 0.5 else Mu2
        X[0, i] = np.random.normal() * Sigma + component_mean
    if isdebug:
        print("***********")
        print("初始观測数据X:")
        print(X)
实现
def e_step(Sigma, k, N):
    """E-step: compute the responsibilities E[z_ij] for every sample.

    For each sample ``X[0, i]`` and component ``j`` this writes into the
    global ``Expectations[i, j]`` the normalised weight

        exp(-(x_i - Mu[j])**2 / (2 * Sigma**2)) / sum over all components,

    so each processed row of ``Expectations`` sums to 1. When the
    module-level flag ``isdebug`` is truthy, the resulting matrix is printed.

    Args:
        Sigma: Standard deviation shared by all components.
        k: Number of mixture components (columns of ``Expectations``).
        N: Number of samples (rows of ``Expectations``).
    """
    global Expectations, Mu, X
    # Loop-invariant factor of the Gaussian exponent.
    scale = -1.0 / (2.0 * float(Sigma) ** 2)
    for i in range(N):
        exponents = [scale * float(X[0, i] - Mu[j]) ** 2 for j in range(k)]
        # Subtract the largest exponent before exponentiating. The ratio is
        # mathematically unchanged, but the largest term becomes exp(0) == 1,
        # so the denominator can never underflow to zero for samples that lie
        # far away from every current mean.
        shift = max(exponents, default=0.0)
        weights = [math.exp(value - shift) for value in exponents]
        total = sum(weights)
        for j in range(k):
            Expectations[i, j] = weights[j] / total
    if isdebug:
        print("***********")
        print("隐藏变量E(Z):")
        print(Expectations)
def m_step(k, N):
    """M-step: re-estimate each component mean from the responsibilities.

    For every component ``j`` the global ``Mu[j]`` is updated in place to the
    responsibility-weighted average of the samples:

        Mu[j] = sum_i Expectations[i, j] * X[0, i] / sum_i Expectations[i, j]

    Args:
        k: Number of mixture components (entries of ``Mu`` to update).
        N: Number of samples to average over.
    """
    global Expectations, X, Mu
    for j in range(k):
        weighted_sum = 0.0
        weight_sum = 0.0
        for i in range(N):
            responsibility = Expectations[i, j]
            weighted_sum += responsibility * X[0, i]
            weight_sum += responsibility
        # A component that no sample is attributed to has no defined weighted
        # mean; keep its previous estimate instead of writing nan (or raising
        # on 0/0 when N == 0).
        if weight_sum > 0:
            Mu[j] = weighted_sum / weight_sum
def run(Sigma, Mu1, Mu2, k, N, iter_num, Epsilon):
    """Generate data and run EM until convergence or ``iter_num`` iterations.

    Calls ``ini_data`` once, then alternates ``e_step`` and ``m_step``. After
    each iteration the current means are printed, and the loop stops early
    once the summed absolute change of the means falls below ``Epsilon``.

    Args:
        Sigma: Standard deviation shared by the Gaussians.
        Mu1: True mean of the first generating Gaussian.
        Mu2: True mean of the second generating Gaussian.
        k: Number of mixture components.
        N: Number of samples to generate.
        iter_num: Maximum number of EM iterations.
        Epsilon: Convergence threshold on the total change of the means.

    Returns:
        The global ``Mu`` array holding the final mean estimates.
    """
    ini_data(Sigma, Mu1, Mu2, k, N)
    print("初始<u1,u2>:", Mu)
    for i in range(iter_num):
        # Snapshot the means: m_step updates Mu in place, so a plain reference
        # would always compare equal to the new values.
        Old_Mu = Mu.copy()
        e_step(Sigma, k, N)
        m_step(k, N)
        print("结果:", i, Mu)
        if abs(Mu - Old_Mu).sum() < Epsilon:
            break
    return Mu
结果与检验
if __name__ == '__main__':
    # Sigma=6, true means 40 and 20, k=2 components, N=1000 samples,
    # at most 1000 iterations, convergence threshold 1e-4.
    run(6, 40, 20, 2, 1000, 1000, 0.0001)
    # Histogram (50 bins) of the generated samples stored in the global X.
    plt.hist(X[0, :], 50)
    plt.show()
最终运行结果如下所示(这是开启isdebug时的输出;由于训练数据和初始均值都是随机生成的,每次运行得到的具体数值会有所不同):
初始观測数据X:
[[27.82508171 37.97597088 35.54020456 39.58627266 38.96730089 41.60105752
44.50692551 17.84718108 37.2103319 26.73458374 22.38552847 17.31929058
42.39008536 12.72705313 49.51154153 38.0554498 16.4623559 13.10063848
41.97315924 14.47106481 51.34841989 21.33792679 20.61216046 47.39830018
48.56275845 40.32628769 56.58113282 17.00858226 30.25188976 38.29119125
19.60098559 17.31995501 36.43074813 27.99847337 29.08375884 20.7708197
38.77496001 43.45446893 40.6674319 15.6923042 44.83199422 19.74667502
19.51540587 20.39477171 29.08708132 35.09968415 21.54355044 14.98601529
22.14301675 14.27206594 39.81949954 30.36900525 17.39542403 19.01650668
39.01620722 48.95382074 32.3066109 15.20510389 49.3978413 29.09091392
37.0973858 36.22156032 37.40889468 19.56498871 44.06723656 41.23102871
23.75506267 24.18407741 51.75666446 21.53417618 43.02513681 46.08165857
22.08217831 18.3137823 15.79003405 46.45496305 46.96101949 43.82361756
34.04859038 16.12533626 20.21203908 30.37877668 43.19275303 35.00993144
17.31302868 29.45903541 38.1717563 13.76932878 31.08765897 37.68976793
33.30295358 35.34149067 41.71318355 14.28713483 39.47944411 33.68213043
32.02776855 36.00482774 45.95440244 42.34969177 23.48939629 33.20148935
29.58815148 15.06327019 38.04466291 18.95309459 14.35020127 47.10957255
23.89756536 17.57743914 25.85476464 35.91240399 18.9455467 25.4891508
38.93484628 38.87652084 39.2405994 19.59619942 25.00010953 28.8421501
9.8289576 22.07534247 16.50746819 31.79764536 42.83883374 35.59364921
38.11083733 42.02796431 40.19015075 25.69103365 23.98644454 23.3916633
14.248387 18.80229762 29.67456271 37.69995493 43.07438618 12.1742234
18.99473377 12.60679542 16.89640117 17.55191902 43.67440012 21.63408961
35.88694797 21.20874464 39.03020929 38.41191146 20.11382617 17.61566874
24.4358113 15.07599294 53.55098315 32.2229376 45.08096447 19.49549889
22.20434653 37.70955694 17.55522933 46.40140046 51.870968 19.89297138
42.63946024 24.94625245 42.02131055 7.84363135 34.01729378 39.72858648
39.46524754 41.42752298 37.56873346 37.47650575 30.74327678 43.76609305
20.24614042 39.1911829 45.03163237 16.0870995 39.76399726 37.07595371
21.09723263 31.4773112 41.85527947 37.85087884 18.94671457 15.41151199
37.19069693 21.15383135 14.94238155 36.10570329 24.50035672 35.73184822
25.24066728 27.02606348 38.41656046 17.79454704 41.43228551 43.88834941
34.86677011 16.45386174 39.80072975 48.54189893 39.81593709 18.41352672
54.73030662 40.88121863 20.28956989 52.02700048 35.29503261 36.03933022
7.29181241 11.69032428 26.53022177 45.3832638 16.31204044 22.55075318
14.36759708 15.47432594 25.52854982 28.71630804 44.7947991 18.71877625
43.64165052 44.27609426 40.59638757 22.23426836 20.60405471 31.49985451
18.8310789 32.96252078 28.3156934 17.65298185 35.89947821 15.91719791
40.99469103 43.50035861 18.60588509 42.26927705 45.99059141 22.44777094
30.52078168 41.02806518 9.22464048 42.09717431 15.08344954 57.09087042
36.29105079 37.61719816 22.25851295 10.96762983 38.38163504 38.80430437
5.31501209 41.75280206 30.90611661 45.5032132 26.14461419 47.02641941
39.72657424 36.58768946 44.60809551 33.39027116 16.93301205 18.00195304
38.1758827 14.12878437 22.36755714 52.39161043 45.07995899 20.35994259
8.98005573 39.68008571 35.3306441 26.32336335 32.00585095 37.76026034
25.87125175 39.33713896 19.88877852 15.89982611 28.04262797 29.10457422
49.48105339 24.13183234 32.74522101 34.44995827 16.82114441 15.72911089
51.84805385 38.21256008 27.13704218 39.59870559 38.96097556 29.55208986
49.13727646 48.01023153 35.13204597 26.59858395 40.92077199 44.86132058
40.42901485 21.32296923 35.28182667 40.29022712 41.59656595 20.39200878
41.29418757 18.76751766 33.25331971 22.63228045 15.81097646 45.37835832
39.01048359 13.99595713 43.19720792 7.72367324 48.32219916 15.54272228
47.67878049 42.3912899 50.66266255 44.34507689 50.80407595 22.52986822
46.06183894 43.48064327 35.94991081 24.30388331 20.7453102 42.17195598
22.94109017 14.12296693 49.68684673 20.42988421 39.67389602 12.37210307
25.43067752 31.50432996 45.34862939 29.66816409 14.08755607 33.41123654
47.87173984 21.65868339 38.97454617 47.98925662 18.54304447 24.92265693
36.36228076 35.35036668 38.61002139 18.0860735 19.3633958 15.89130537
7.11616844 36.29398065 26.44092678 48.02508795 40.93649313 44.13885668
42.45463177 44.45971774 37.10274351 20.35955514 37.69453373 32.78798442
16.54796461 20.7439167 29.92891167 23.27756589 43.52142664 35.27127152
42.24457432 45.24268419 18.29287734 43.43541818 18.01720786 51.3652603
15.9379641 50.73462423 41.02253061 46.38463099 46.26464705 33.64568754
42.51408324 39.14850928 34.36084733 52.00526528 5.88341431 41.60643383
14.13196661 42.90963082 32.82364936 38.14589537 23.36796869 45.03163016
31.32210371 34.2644015 34.51096367 26.22829718 40.28721232 13.47051907
32.45127187 34.63450477 29.79027483 33.93951218 56.76205705 39.61869863
36.46873725 25.05402162 39.93448144 21.02643074 44.38618813 35.34206804
21.31557016 44.35131056 35.22787305 25.78058197 36.63191 21.16490575
18.18925306 21.59372451 10.57777289 21.99065975 18.05094352 20.26349026
27.71497441 43.630469 40.20793402 13.68424697 21.33961953 11.22512126
48.87054302 15.12661662 48.31481681 21.24923944 42.47867691 42.21226743
26.53428929 40.67694765 24.38390017 39.78819889 24.23897441 52.74139446
49.91110614 40.68395702 38.4019648 23.60147304 34.70874853 9.76586745
17.32179856 37.41070552 45.67120981 31.48516434 30.47878988 15.11533931
10.73092402 33.66687199 28.65206295 18.47470172 19.0763261 33.58387568
24.43618272 19.29386897 14.82167457 41.23754884 48.30239251 30.80048163
30.93536142 39.17318429 46.53169227 26.1006547 36.48909197 40.440806
50.20983006 6.31883527 12.13501331 45.37287598 28.5149891 41.37227771
21.95392989 37.10135223 25.4521454 39.6219183 26.69061824 22.71666613
42.43192508 50.08133615 24.07667755 43.4532035 48.83012955 14.05680729
16.53145374 40.47069782 29.6114959 28.60687811 12.64532036 23.19991157
39.08576892 23.04890768 44.18486293 38.920376 26.34674938 44.52768027
21.89312157 49.90128 18.39229895 39.72158998 38.91330495 47.78073316
12.42466893 18.42326987 14.62815125 14.6076352 30.06060689 23.40469327
39.51375538 17.77795124 16.73764303 39.66714503 32.39609196 25.21189579
18.51746569 18.77661425 25.8155471 37.40893172 8.83413918 30.75581678
19.98458347 26.00411229 32.68196323 35.10694309 47.93841104 44.69363435
17.60860722 17.92575616 35.94388916 18.69783375 21.37421542 19.10941516
43.41737504 36.66716387 24.66677505 7.31782448 41.10154371 42.11681949
43.77607398 41.0578741 37.04937228 19.322137 25.07479335 12.6505293
28.38942228 34.8649858 11.53591988 35.3526631 12.96158256 40.32412886
39.28315936 44.62761402 42.49805934 34.21414541 24.89697877 9.31049864
14.91216216 15.97551475 23.60469889 17.32473523 42.09344977 31.69328171
21.35102071 40.92200856 16.91979482 22.33599679 38.19366769 30.86513409
18.6451735 25.3947045 15.03365557 40.12043856 33.55935135 18.45316858
21.67447831 18.99724418 13.4342408 28.04046538 27.53612199 40.93485868
30.52838191 46.10337729 10.51514761 32.77127393 40.30040117 43.90687785
29.14984838 20.71267093 25.17460936 27.69083702 44.16021871 29.01014386
17.97951553 12.78621133 23.9928398 41.64867037 37.48174958 17.65500555
41.008314 40.2864353 36.86430586 14.56410234 23.56087173 24.2462363
58.3481558 34.51802301 37.33050191 32.20490846 16.96125721 46.64365711
25.97316648 52.37693374 41.1520578 21.38912655 36.43711407 13.19271057
24.46429536 18.66082221 43.29136046 22.97359162 20.17175502 17.37711392
6.92121256 28.33564963 38.43896407 16.15134353 36.18364041 38.58258903
38.69435858 45.61364977 39.47432926 27.03284364 23.84784369 16.02889099
20.41924305 9.7812151 26.25487509 10.85208225 42.26877546 18.78097034
11.59296901 24.2732992 21.47047164 45.41588856 18.10928557 32.38275969
26.61354068 37.79590718 20.14479519 18.01738305 45.01705147 9.16830473
23.94037404 22.37004421 34.057725 43.1964883 14.35717223 42.19255777
51.38189932 14.12467692 31.524625 41.39167429 40.70645735 20.53365395
35.57735785 33.24231646 14.84561636 22.84375471 23.08336156 38.65029935
38.8538061 43.06849383 35.43093218 19.42276207 41.10598065 31.78679482
26.98841717 23.79175849 21.64037543 12.23346342 36.97727332 29.10774063
44.72432092 21.67427689 11.92450442 27.12066246 51.80914374 38.0895041
35.5053405 34.29507073 25.73324019 42.55959584 36.20965954 39.41186157
25.97712065 38.83252096 35.91646675 45.5747211 41.38440745 24.73821561
52.01779132 25.7166098 31.36513783 18.72351744 29.14095844 21.51909294
25.99910155 46.57924429 17.18040639 18.75470316 21.54230454 41.76687103
11.27774029 42.00648747 35.95079273 18.88443164 9.13719109 47.74701137
31.59600606 4.96650488 12.34059601 7.67971171 17.99235487 15.46688071
43.45947645 10.43254121 43.80399117 41.07151077 46.29868278 11.81018481
13.57901907 43.13738591 46.29289225 28.01801757 37.15856223 17.81437588
43.59669532 37.96217461 43.06789467 42.85189764 22.0364769 45.08305799
22.72201217 16.71540817 18.17781912 42.24368067 46.41879967 29.82920193
41.28146516 26.21034716 20.96371233 16.0217808 27.02773844 49.19804156
18.26963876 39.69783495 22.99510595 32.16189711 38.03414758 9.58839628
36.64003693 23.48967321 20.07334478 25.08305685 20.9244944 15.44374043
24.30048967 17.66954304 10.78184431 27.57815549 32.53679069 44.8949459
38.80086919 18.71433253 34.43055433 9.48920882 47.28601596 17.41575486
18.13441959 26.90502847 14.10120548 34.53298101 24.03832328 55.36588886
15.41315049 28.96511357 52.06325974 37.15863206 41.16012355 9.23449111
21.80448989 9.77081873 13.23109594 21.18105792 21.72848835 17.1679405
35.35438026 35.46029275 27.16808705 10.35566489 37.06140497 16.73441939
50.94206604 32.06281382 14.36902525 54.98157203 18.2117655 12.4976392
36.66189987 15.50597805 38.80011466 30.8147893 40.67350441 45.84660316
50.28401983 29.42854654 18.37863367 22.73409577 18.92969342 21.6922951
18.96943313 18.30291632 38.13181899 45.55286703 40.89666396 35.74621002
8.36533168 19.55667989 39.74022126 32.22772338 10.86284136 29.07521249
16.78246814 48.22972103 16.28266534 32.84638392 21.42332169 26.71585309
41.84995405 36.66288227 37.93646375 25.40754874 35.31271967 18.30815221
40.64083493 20.42916485 30.51996761 15.02850269 20.41400096 17.27382077
15.30953153 19.9908306 23.1749839 17.10746932 16.12077434 38.54445522
11.59226008 12.13558108 44.51403834 39.31751377 25.74900341 14.8751611
44.33667916 35.82841562 31.24090714 40.8821205 19.40707126 17.49213186
28.52839086 46.72630711 11.91832342 21.23145212 21.41952141 33.92337877
14.89336751 44.83728193 28.46945207 42.03703171 33.38096373 22.84013529
23.93915246 21.75762012 15.75110819 32.41130458 18.07728959 26.61345702
40.18648733 41.15382776 19.36665021 27.43986754 38.36471338 17.36159944
20.81724183 42.15246589 43.31710204 46.78219948 36.27339381 36.00442998
46.41965282 33.07857399 23.91788524 44.82693368 52.07620816 22.3585279
42.11768557 39.23103454 20.58113194 17.79599493 17.69926475 18.15538468
34.34378092 17.86166738 19.14955363 16.46235756 37.51574974 41.690254
21.44647069 48.28326381 21.24697672 26.15118796 34.06677783 33.1241191
29.13698911 20.10905589 41.2211431 43.92991665 24.15846121 24.75530809
23.68513737 21.5555631 45.93625348 39.96319438 37.92633927 48.15207567
43.26305589 38.0726257 28.04263704 42.02848951 19.73313496 33.28887636
29.83163463 15.07626436 27.52607613 39.91201099 18.6985502 42.62243876
24.12815945 22.50012732 23.69640859 32.6455529 35.84833881 24.99116212
14.2925821 45.32030922 15.87292631 43.55585758 35.68965166 17.47690128
43.2955071 13.95850965 18.09354387 39.32958487 21.47076189 13.65851729
24.21876978 13.10675192 40.35950187 30.65706959 19.95057346 29.25485208
17.8868581 41.43449704 16.97514034 24.07238235 40.75472182 25.05974823
21.79789874 47.23108515 20.36966043 43.33184144 19.57828951 36.23658191
19.05808244 14.02185343 12.63031705 19.78647765 12.24945604 16.84694144
16.65546736 35.99267841 36.55052193 46.35042294]]
初始<u1,u2>: [0.685575 0.46780956]
隐藏变量E(Z):
[[0.54111361 0.45888639]
[0.55631744 0.44368256]
[0.55267771 0.44732229]
…
[0.55335427 0.44664573]
[0.55418812 0.44581188]
[0.5687829 0.4312171 ]]
结果: 0 [30.43995828 29.60529488]
隐藏变量E(Z):
[[0.48726518 0.51273482]
[0.54596957 0.45403043]
[0.53193785 0.46806215]
…
[0.53454891 0.46545109]
[0.5377654 0.4622346 ]
[0.59352611 0.40647389]]
结果: 1 [31.64177797 28.47635804]
隐藏变量E(Z):
[[0.45104962 0.54895038]
[0.66732698 0.33267302]
[0.61820757 0.38179243]
…
[0.62755278 0.37244722]
[0.63894376 0.36105624]
[0.80728614 0.19271386]]
结果: 2 [35.33010041 24.79664563]
隐藏变量E(Z):
[[0.34188281 0.65811719]
[0.91012762 0.08987238]
[0.83236734 0.16763266]
…
[0.85003694 0.14996306]
[0.86967787 0.13032213]
[0.99155354 0.00844646]]
结果: 3 [39.47916421 20.65525362]
隐藏变量E(Z):
[[2.36425449e-01 7.63574551e-01]
[9.84255340e-01 1.57446598e-02]
[9.45923143e-01 5.40768566e-02]
…
[9.56824527e-01 4.31754728e-02]
[9.67391688e-01 3.26083122e-02]
[9.99799455e-01 2.00545017e-04]]
结果: 4 [40.15963493 19.94854599]
隐藏变量E(Z):
[[2.22456473e-01 7.77543527e-01]
[9.88427632e-01 1.15723677e-02]
[9.56060581e-01 4.39394186e-02]
…
[9.65578197e-01 3.44218028e-02]
[9.74598589e-01 2.54014111e-02]
[9.99893693e-01 1.06307020e-04]]
结果: 5 [40.1934646 19.90211758]
隐藏变量E(Z):
[[2.22211178e-01 7.77788822e-01]
[9.88667763e-01 1.13322367e-02]
[9.56718784e-01 4.32812159e-02]
…
[9.66131906e-01 3.38680938e-02]
[9.75041234e-01 2.49587657e-02]
[9.99897849e-01 1.02151188e-04]]
结果: 6 [40.1930595 19.89759593]
隐藏变量E(Z):
[[2.22407335e-01 7.77592665e-01]
[9.88693451e-01 1.13065491e-02]
[9.56802219e-01 4.31977809e-02]
…
[9.66199526e-01 3.38004740e-02]
[9.75093074e-01 2.49069261e-02]
[9.99898181e-01 1.01819470e-04]]
结果: 7 [40.19222716 19.89654511]
隐藏变量E(Z):
[[2.22496822e-01 7.77503178e-01]
[9.88699921e-01 1.13000788e-02]
[9.56825532e-01 4.31744684e-02]
…
[9.66218036e-01 3.37819642e-02]
[9.75106921e-01 2.48930789e-02]
[9.99898245e-01 1.01755372e-04]]
结果: 8 [40.19187574 19.89617881]
隐藏变量E(Z):
[[2.22531664e-01 7.77468336e-01]
[9.88702218e-01 1.12977821e-02]
[9.56833982e-01 4.31660175e-02]
…
[9.66224719e-01 3.37752808e-02]
[9.75111897e-01 2.48881032e-02]
[9.99898266e-01 1.01734104e-04]]
结果: 9 [40.19174012 19.89604099]
隐藏变量E(Z):
[[2.22544976e-01 7.77455024e-01]
[9.88703084e-01 1.12969157e-02]
[9.56837180e-01 4.31628202e-02]
…
[9.66227246e-01 3.37727536e-02]
[9.75113777e-01 2.48862230e-02]
[9.99898274e-01 1.01726162e-04]]
结果: 10 [40.19168837 19.89598857]
隐藏变量E(Z):
[[2.22550049e-01 7.77449951e-01]
[9.88703414e-01 1.12965861e-02]
[9.56838397e-01 4.31616033e-02]
…
[9.66228208e-01 3.37717918e-02]
[9.75114492e-01 2.48855075e-02]
[9.99898277e-01 1.01723145e-04]]
结果: 11 [40.19166865 19.89596861]
完整代码
https://github.com/canshang/-/blob/master/EM.ipynb