Weka28 EM源代码分析.docx
- 文档编号:29561788
- 上传时间:2023-07-24
- 格式:DOCX
- 页数:22
- 大小:204.77KB
Weka28 EM源代码分析.docx
《Weka28 EM源代码分析.docx》由会员分享,可在线阅读,更多相关《Weka28 EM源代码分析.docx(22页珍藏版)》请在冰豆网上搜索。
Weka28EM源代码分析
Weka[28]EM源代码分析
作者:
Koala++/屈伟
EM算法在clusterers下面,提一下是因为我没有想到它竟然在这里,而且它的名字也
太大了点,因为这里它只是与SimpleKMeans结合的算法。
引自AndrewNg的LecturenotesmixturesofGaussiansandtheEMalgorithm:
The
EM-algorithmisalsoreminiscentoftheK-meansclusteringalgorithm,exceptthatinsteadof
“hard”clusterassignmentc(i),weinsteadhavethe“soft”assignmentw_j^(i).SimilartoK-means,
itisalsosusceptibletolocaloptima,soreinitializingatseveraldifferentinitialparametersmaybe
agoodidea。
Soft指的是我们猜测是概率,取值在[0,1]区间,相反,“hard”猜测是指单个最好的猜测,
可以取值在{0,1}或是{1,…,k}。
英文原文:
Theterm“soft”referstoourguessesbeing
probabilitiesandtakingvaluesin[0,1];incontrast,a“hard”guessisonethatrepresentsasingle
bestguess(suchastakingvaluesin{0,1}or{1,…,k})
下面的图来自NgAndrew和BishopChistopher,第一组图K-Means的猜测是两个点,
而第二组图EM是对概率的猜测。
另一点是刚才文中提到的,多个初始化点,在代码中也体现了。
Ng在对EM算法收敛证明之后,解释如下:
Hence,EMcausesthelikelihoodtoconverge
monotonically.InourdescriptionoftheEMalgorithm,wesaidwe'drunituntilconvergence.
Giventheresultthatwejustshowed,onereasonableconvergencetestwouldbetocheckifthe
increaseinl(theta)betweensuccessiveiterationsissmallerthansometoleranceparameter,andto
declareconvergenceifEMisimprovingl(theta)tooslowly.
从buildCluster开始:
if(data.checkForStringAttributes()){
thrownewException("Can'thandlestringattributes!
");
}
m_replaceMissing=newReplaceMissingValues();
Instancesinstances=newInstances(data);
instances.setClassIndex(-1);
m_replaceMissing.setInputFormat(instances);
data=weka.filters.Filter.useFilter(instances,m_replaceMissing);
instances=null;
m_theInstances=data;
//calculateminandmaxvaluesforattributes
m_minValues=newdouble[m_theInstances.numAttributes()];
m_maxValues=newdouble[m_theInstances.numAttributes()];
for(inti=0;i m_minValues[i]=m_maxValues[i]=Double.NaN; } for(inti=0;i updateMinMax(m_theInstances.instance(i)); } ReplaceMissingValues是将缺失值用平均值或中位数代替。 m_minValues和m_maxValues 是每个属性的最小值与最大值数组。 privatevoidupdateMinMax(Instanceinstance){ for(intj=0;j if(! instance.isMissing(j)){ if(Double.isNaN(m_minValues[j])){ m_minValues[j]=instance.value(j); m_maxValues[j]=instance.value(j); }else{ if(instance.value(j) m_minValues[j]=instance.value(j); }else{ if(instance.value(j)>m_maxValues[j]){ m_maxValues[j]=instance.value(j); } } } } } } Double.isNan这里是判断是不是还没有一个真正的属性值来代替过它。 其它的代码就是 找第j个属性的最大值和最小值。 doEM(); //savememory m_theInstances=newInstances(m_theInstances,0); doEM之后就是释放空间了,那么所有的工作都是在doEM中完成的: privatevoiddoEM()throwsException{ m_rr=newRandom(m_rseed); //throwawaynumberstoavoidproblemofsimilarinitialnumbers //fromasimilarseed for(inti=0;i<10;i++) m_rr.nextDouble(); m_num_instances=m_theInstances.numInstances(); m_num_attribs=m_theInstances.numAttributes(); //setDefaultStdDevs(theInstances); //crossvalidatetodeterminenumberofclusters? if(m_initialNumClusters==-1){ if(m_theInstances.numInstances()>9){ CVClusters(); m_rr=newRandom(m_rseed); for(inti=0;i<10;i++) m_rr.nextDouble(); }else{ m_num_clusters=1; } } //fitfulltrainingset EM_Init(m_theInstances); m_loglikely=iterate(m_theInstances,m_verbose); } 丢弃从同一个种子得到的随机数,这个与下面的代码有关? 如果m_initialNumClusters ==-1表明没有指定要聚多少个类,那么要用crossvalidate来决定聚多少个类。 如果样本数 大于9,用CVClusters函数来决定。 如果小于9个样本,就认为就一个类。 EM_Init初始化, 然后迭代,先不去管CVClusters,认为已经指定了m_initialNumClusters,那么先看EM_Init: //runkmeans10timesandchoosebestsolution SimpleKMeansbestK=null; doublebestSqE=Double.MAX_VALUE; for(i=0;i<10;i++){ SimpleKMeanssk=newSimpleKMeans(); sk.setSeed(m_rr.nextInt()); sk.setNumClusters(m_num_clusters); sk.buildClusterer(inst); if(sk.getSquaredError() bestSqE=sk.getSquaredError(); bestK=sk; } } 这里是用不同的随机种子初始化,最后求得一个最好的SimpleKMeans对象。 //initializewithbestk-meanssolution m_num_clusters=bestK.numberOfClusters(); m_weights=newdouble[inst.numInstances()][m_num_clusters]; m_model=newDiscreteEstimator[m_num_clusters][m_num_attribs]; m_modelNormal=newdouble[m_num_clusters][m_num_attribs][3]; m_priors=newdouble[m_num_clusters]; Instancescenters=bestK.getClusterCentroids(); InstancesstdD=bestK.getClusterStandardDevs(); int[][][]nominalCounts=bestK.getClusterNominalCounts(); int[]clusterSizes=bestK.getClusterSizes(); centers是聚类后的所有中心点,stdD是标准差,而nominalCounts第一维大小为所聚类 的个数,第二维属性数,第三级该维的取值数。 for(i=0;i Instancecenter=centers.instance(i); for(j=0;j if(inst.attribute(j).isNominal()){ m_model[i][j]=newDiscreteEstimator(m_theInstances .attribute(j).numValues(),true); for(k=0;k m_model[i][j].addValue(k,nominalCounts[i][j][k]); } }else{ doubleminStdD=(m_minStdDevPerAtt! =null)? m_minStdDevPerAtt[j]: m_minStdDev; doublemean=(center.isMissing(j))? inst.meanOrMode(j) : center.value(j); m_modelNormal[i][j][0]=mean; doublestdv=(stdD.instance(i).isMissing(j))? ((m_maxValues[j]-m_minValues[j])/(2*m_num_clusters)) : stdD.instance(i).value(j); if(stdv stdv=inst.attributeStats(j).numericStats.stdDev; if(Double.isInfinite(stdv)){ stdv=minStdD; } if(stdv stdv=minStdD; } } if(stdv<=0){ stdv=m_minStdDev; } m_modelNormal[i][j][1]=stdv; m_modelNormal[i][j][2]=1.0; } } } 这里DiscreteEstimator是针对离散数据进行统计的一个类,构造函数如下: publicDiscreteEstimator(intnumSymbols,booleanlaplace){ m_Counts=newdouble[numSymbols]; m_SumOfCounts=0; if(laplace){ for(inti=0;i m_Counts[i]=1; } } m_SumOfCounts=(double)numSymbols; } 这里使用了laplace平滑,m_Counts初始为1,也就是平常所见过的公式加上1,并且 SumOfCounts也初始化为取值的个数,也就是公式中分母最后加的那个数。 publicvoidaddValue(doubledata,doubleweight){ m_Counts[(int)data]+=weight; m_SumOfCounts+=weight; } addValue函数很简单就是在第几个取值上,加上相应的权重。 写这么麻烦是因为不能得 到连续值的估计。 minStdD控制精度,平均值是中心点的取值,而stdv就是在SimpleKMeans 中计算出的值。 M_modelNormal[i][j][0]是均值,M_modelNormal[i][j][1]是方差, M_modelNormal[i][j][2]记录的是概率。 for(j=0;j //m_priors[j]+=1.0; m_priors[j]=clusterSizes[j]; } Utils.normalize(m_priors); 通过每个所聚类的大小,算出先验概率。 privatedoubleiterate(Instancesinst,booleanreport)throwsException { inti; doublellkold=0.0; doublellk=0.0; booleanok=false; intseed=m_rseed; intrestartCount=0; while(! ok){ try{ for(i=0;i llkold=llk; llk=E(inst,true); if(i>0){ if((llk-llkold)<1e-6){ break; } } M(inst); } ok=true; }catch(Exceptionex){ } } returnllk; } 可以看到有两种迭代中止的方法,第一种是达到了m_max_iterations,第二种是llk-llkold 小于阈值,llk是loglikelihood的缩写,EM的目标就是最大化它,如果已经接近最优值了, 所以就停止了。 下面就是E: privatedoubleE(Instancesinst,booleanchange_weights)throwsException { doubleloglk=0.0,sOW=0.0; for(intl=0;l Instancein=inst.instance(l); loglk+=in.weight()*logDensityForInstance(in); sOW+=in.weight(); if(change_weights){ m_weights[l]=distributionForInstance(in); } } //reestimatepriors if(change_weights){ estimate_priors(inst); } returnloglk/sOW; } 这里logDensityForInstance的代码: publicdoublelogDensityForInstance(Instanceinstance)throwsException { double[]a=logJointDensitiesForInstance(instance); doublemax=a[Utils.maxIndex(a)]; doublesum=0.0; for(inti=0;i sum+=Math.exp(a[i]-max); } returnmax+Math.log(sum); } logJointDensitiesForInstance: publicdouble[]logJointDensitiesForInstance(Instanceinst) throwsException{ double[]weights=logDensityPerClusterForInstance(inst); double[]priors=clusterPriors(); for(inti=0;i if(priors[i]>0){ weights[i]+=Math.log(priors[i]); }else{ thrownewIllegalArgumentException("Clusterempty! "); } } returnweights; } logDensityPerClusterForInstance: publicdouble[]logDensityPerClusterForInstance(Instanceinst) throwsException{ inti,j; doublelogprob; double[]wghts=newdouble[m_num_clusters]; m_replaceMissing.input(inst); inst=m_replaceMissing.output(); for(i=0;i logprob=0.0; for(j=0;j if(! inst.isMissing(j)){ if(inst.attribute(j).isNominal()){ logprob+=Math.log(m_model[i][j]. getProbability(inst.value(j))); }else{//numericattribute logprob+=logNormalDens(inst.value(j), m_modelNormal[i][j][0], m_modelNormal[i][j][1]); } } } wghts[i]=logprob; } returnwghts; } 对于离散型属性,这里计算它的概率的对数,它的概率计算很简单: publicdoublegetProbability(doubledata){ if(m_SumOfCounts==0){ return0; } return(double)m_Counts[(int)data]/m_SumOfCounts; } 就是Laplace平滑后的概率,而对于连续属性: privatestaticdoublem_normConst=Math.log(Math.sqrt(2*Math.PI)); privatedoublelogNormalDens(doublex,doublemean,doublestdDev){ doublediff=x-mean; return-(diff*diff/(2*stdDev*stdDev))-m_normConst -Math.log(stdDev); } Diff就是高斯分布(正态分布)中的x-mean,而下面的那一长串就是对高斯分布的公式对 数化得到的。 回到logJointDensitiesForInstance中,clusterPriors的代码如下: publicdouble[]clusterPriors(){ double[]n=newdouble[m_priors.length]; System.arraycopy(m_priors,0,n,0,n.length); returnn; } 只是复制一下,而在clusterPriors后,weights[i]+=Math.log(priors[i])还是取对数后,可 以展开来,也就是P(xi|zi)P(zi)的P(zi),log(P(xi|zi)P(zi))=log(P(xi|zi))+log(P(zi))。 再回到logDensityForInstance中,这里的Math.exp(a[i]-max)这个看起来奇怪,公式里没 有a[i]-max,这里可以把max这个量想成要用别的a[i]组合得到的,因为sum(a[i])=1,可以 参考一下AndrewNg的lecturen
- 配套讲稿:
如PPT文件的首页显示word图标,表示该PPT已包含配套word讲稿。双击word图标可打开word文档。
- 特殊限制:
部分文档作品中含有的国旗、国徽等图片,仅作为作品整体效果示例展示,禁止商用。设计者仅对作品中独创性部分享有著作权。
- 关 键 词:
- Weka28 EM源代码分析 EM 源代码 分析