sparkmllib源码阅读分类算法4DecisionTree文档格式.docx
- 文档编号:21530423
- 上传时间:2023-01-31
- 格式:DOCX
- 页数:18
- 大小:205.66KB
sparkmllib源码阅读分类算法4DecisionTree文档格式.docx
《sparkmllib源码阅读分类算法4DecisionTree文档格式.docx》由会员分享,可在线阅读,更多相关《sparkmllib源码阅读分类算法4DecisionTree文档格式.docx(18页珍藏版)》请在冰豆网上搜索。
LearningNode:
决策树训练时结点的表示类LearningNode,在训练完成后通过LearningNode.toNode方法,将其转变为InternalNode或者LeafNode。
说一下几个参数的意思:
prediction:
预测类别或者回归值
impurity:
不纯度,Spark实现了三种不纯度度量方式:
熵、信息增益、残差(适用于回归)。
leftChild、rightChild:
左右子节点
split:
Node在进行预测时,需要用到split存储的结点信息,由split来决定选择左结点还是右结点。
结点分裂信息类Split:
Spark实现了2个结点选取类CategoricalSplit和ContinuousSplit,分别完成分类特征和连续特征下的子结点选取问题。
CategoricalSplit:
将分类特征的属性值集分成2个集合(左集合)和右集合,判断属性值属于哪个集合来决定选取哪个子节点。
ContinuousSplit:
针对连续型特征的子节点选取类,输入的特征值与设定的阀值threshold比较大小,来决定是选取左子节点还是右子结点。
决策树特征选择与分裂:
选择一个合适的特征作为判断节点,可以快速的分类,减少决策树的深度。
决策树的目标就是把数据集按对应的类标签进行分类。
最理想的情况是,通过特征的选择能把不同类别的数据集贴上对应类标签。
特征选择的目标使得分类后的数据集比较纯。
Spark实现了3类数据不纯度度量算法:
Giniimpurity、Entropy、Variance,都继承自Impurity类并覆写了不纯度计算方法calculate。
Giniimpurity:
采用基尼指数来度量数据的不纯度,计算公式如下:
计算代码如下:
[java]viewplaincopyprint?
在CODE上查看代码片派生到我的代码片
<
precode_snippet_id="
2325202"
snippet_file_name="
blog_20170411_1_1999783"
name="
code"
class="
java"
>
overridedefcalculate(counts:
Array[Double],totalCount:
Double):
Double={
if(totalCount==0){
return0
}
valnumClasses=counts.length
//∑Ci=1fi(1−fi)=∑Ci=1fi+∑Ci=1fi*fi,其中前半部分为1实际只需要计算后半部分。
varimpurity=1.0
varclassIndex=0
while(classIndex<
numClasses){
valfreq=counts(classIndex)/totalCount//fi
impurity-=freqfreq
classIndex+=1
impurity
}<
/pre>
Entropyimpurity:
采用熵来度量数据的不纯度,计算公式如下:
//∑Ci=1−filog(fi)
varimpurity=0.0
valclassCount=counts(classIndex)
if(classCount!
=0){
valfreq=classCount/totalCount
impurity-=freqlog2(freq)
Varianceimpurity:
使用残差度量数据不纯度,使用决策树回归问题,计算公式如下:
实现代码:
blog_20170411_20_2034920"
overridedefcalculate(count:
Double,sum:
Double,sumSquares:
if(count==0){
valsquaredLoss=sumSquares-(sum*sum)/count
squaredLoss/count
特征选取的方式是子结点的总数据不纯度小于当前结点的数据不纯度,并且其差值越大越好,即结点的分裂总是朝着数据纯度提高的方向进行:
分裂候选集:
上面提到,Spark的决策树的基本形态是一颗二叉树,那么在每个非叶子结点上,都需要选择特征并将特征值一分为二,并根据样本的特征值的归属来决定样本分配至哪一个子节点。
分裂候选集即是来完成特征值一分为二的过程,和切西瓜那样一刀切下去会有很多种不同的切分类似,分裂候选集也会产生很多种对特征值集不同的切分方法,之后在模型训练时选择一种最优的切法。
分裂候选集是将当前的输入特征的属性值集分成两大属性值集合或者两个区间,如例子中婚姻状态有已婚、未婚、不知,那么可以构造多个两两互斥的属性值集<
已婚>
。
未婚、不知>
、<
已婚、不知>
未婚>
等等。
对于分类型特征,如果特征值有M个可能取值,则可以构造个分裂候选。
如果特征有100个可能值,那么可能的分裂选项就非常的多,搜索起来也很昂贵。
因此有必要减少可能的分裂候选数量,基本方法是将特征值按分裂后的纯度或者与目标类的相关性进行排序,以上为例,假设已婚、未婚、不知分别与目标因变量lable=1的相关性为0.6、0.4、0.2,那么可能的划分是<
已婚、未婚>
不知>
两种。
因此M个可能取值的特征,其进行排序后可能存在的切分点为M-1个。
对于连续特征,需要先对所有的取值进行排序才能寻找可能的切分点。
由于大数据下的值排序是比较昂贵的,因此采用了抽样的方式获得一个特征值子集来构造分裂候选集。
看看该部分的实现代码,分类候选集代码在org.apache.spark.ml.tree.impl.RandomForest中
protected[tree]deffindSplits(
input:
RDD[LabeledPoint],//输入数据
metadata:
DecisionTreeMetadata,//元信息
seed:
Long):
Array[Array[Split]]={
logDebug("
isMulticlass="
+metadata.isMulticlass)
valnumFeatures=metadata.numFeatures//特征数量
valcontinuousFeatures=Range(0,numFeatures).filter(metadata.isContinuous)//得到连续特征的index
//对连续特征分裂构建所需的子样本集,参看该篇"
分裂候选集"
章节
valsampledInput=if(continuousFeatures.nonEmpty){
valrequiredSamples=math.max(metadata.maxBins*metadata.maxBins,10000)//估算抽样数量
valfraction=if(requiredSamples<
metadata.numExamples){////计算抽样率
requiredSamples.toDouble/metadata.numExamples
}else{
1.0
fractionofdatausedforcalculatingquantiles="
+fraction)
//进行无放回抽样该抽样的实现方式可参考
input.sample(withReplacement=false,fraction,newXORShiftRandom(seed).nextInt())
input.sparkContext.emptyRDD[LabeledPoint]
findSplitsBySorting(sampledInput,metadata,continuousFeatures)//找到分裂点
}
privatedeffindSplitsBySorting(
RDD[LabeledPoint],//无放回抽样后的子样本集
DecisionTreeMetadata,//元信息
continuousFeatures:
IndexedSeq[Int]//连续特征的index
):
//这一步是找到连续特征的多个可能分裂点
valcontinuousSplits:
scala.collection.Map[Int,Array[Split]]={
valnumPartitions=math.min(continuousFeatures.length,input.partitions.length)
input
.flatMap(point=>
continuousFeatures.map(idx=>
(idx,point.features(idx))))//得到的RDD是RDD<
features_idx,features_value>
.groupByKey(numPartitions)//<
features_idx,list<
features_value>
.map{case(idx,samples)=>
valthresholds=findSplitsForContinuousFeature(samples,metadata,idx)//连续特征分裂候选的排序、分裂函数
valsplits:
Array[Split]=thresholds.map(thresh=>
newContinuousSplit(idx,thresh))
logDebug(s"
featureIndex=$idx,numSplits=${splits.length}"
)
(idx,splits)
}.collectAsMap()//amapthatcontains<
idx,splits>
//将连续特征和分类特征的分裂点合并并返回
valnumFeatures=metadata.numFeatures
Array[Array[Split]]=Array.tabulate(numFeatures){
caseiifmetadata.isContinuous(i)=>
valsplit=continuousSplits(i)
metadata.setNumSplits(i,split.length)
split
caseiifmetadata.isCategorical(i)&
&
metadata.isUnordered(i)=>
//Unorderedfeatures
//2^(maxFeatureValue-1)-1combinations
valfeatureArity=metadata.featureArity(i)
Array.tabulate[Split](metadata.numSplits(i)){splitIndex=>
valcategories=extractMultiClassCategories(splitIndex+1,featureArity)
newCategoricalSplit(i,categories.toArray,featureArity)
caseiifmetadata.isCategorical(i)=>
//Orderedfeatures
//Splitsareconstructedasneededduringtraining.
Array.empty[Split]
splits
//这一步是找到很多可能的分裂点
private[tree]deffindSplitsForContinuousFeature(
featureSamples:
Iterable[Double],//特征featureIndex的值集合featureSamples
DecisionTreeMetadata,
featureIndex:
Int):
Array[Double]={
require(metadata.isContinuous(featureIndex),
"
findSplitsForContinuousFeaturecanonlybeusedtofindsplitsforacontinuousfeature."
valsplits=if(featureSamples.isEmpty){
Array.empty[Double]
valnumSplits=metadata.numSplits(featureIndex)//分裂数
//getcountforeachdistinctvalue
val(valueCountMap,numSamples)=featureSamples.foldLeft((Map.empty[Double,Int],0)){//(Map.empty[Double,Int],0)是foldLeft函数传入的初始值
case((m,cnt),x)=>
//(m,cnt)已经累加的值,x为新传入的值即featureSamples中的值
(m+((x,m.getOrElse(x,0)+1)),cnt+1)//每个特征值得数量在valueCountMap中key=numSamples中value,valueCountMap中value=该特征值的计数,,numSamples为总的计数
//sortdistinctvalues
valvalueCounts=valueCountMap.toSeq.sortBy(_._1).toArray//根据特征值的大小排序
//ifpossiblesplitsisnotenoughorjustenough,justreturnallpossiblesplits
valpossibleSplits=valueCounts.length-1
if(possibleSplits<
=numSplits){
valueCounts.map(_._1).init//如果特征值数目小于设定值,直接返回所有的特征值作为分裂值
//stridebetweensplits
valstride:
Double=numSamples.toDouble/(numSplits+1)
stride="
+stride)
//iterate`valueCount`tofindsplits
valsplitsBuilder=mutable.ArrayBuilder.make[Double]
varindex=1
//currentCount:
sumofcountsofvaluesthathavebeenvisited
varcurrentCount=valueCounts(0)._2
//这里划分成numSplits个分裂点,并使划分后每箱样本数量均衡
vartargetCount=stride
while(index<
valueCounts.length){
valpreviousCount=currentCount
currentCount+=valueCounts(index)._2//取得当前值的样本数量
valpreviousGap=math.abs(previousCount-targetCount)
valcurrentGap=math.abs(currentCount-targetCount)
//IfaddingcountofcurrentvaluetocurrentCount
//makesthegapbetweencurrentCountandtargetCountsmaller,
//previousvalueisasplitthreshold.
if(previousGap<
currentGap){
splitsBuilder+=valueCounts(index-1)._1
targetCount+=stride
index+=1
splitsBuilder.result()
第二部分:
决策树的整体训练流程
上面介绍了决策树的基本概念、决策树的存储与表示、决策树特征选择算法、特征值分裂候选集等和决策树息息相关的一些概念和算法。
以及在"
spark.mllib源码阅读-bagging方法"
介绍的随机森林的样本子集抽样算法。
决策树的训练过程在上面各个组件的基础上,通过特征值分裂候选集来对特征进行值集合分箱,再在子样本集上重复的进行特征选择算法来选取每个结点的最优特征与特征值划分来构造树的结点,直至满足结点分裂的终止规则。
下面以一个实例开始,来一步步的剖析的决策树的整个训练过程。
以下实例摘自org.apache.spark.examples.mllib.JavaDecisionTreeClassificationExample。
SparkConfsparkConf=newSparkConf().setAppName("
JavaDecisionTreeClassificationExample"
);
JavaSparkContextjsc=newJavaSparkContext(sparkConf);
//Loadandparsethedatafile.
Stringdatapath="
data/mllib/sample_libsvm_data.txt"
;
JavaRDD<
LabeledPoint>
data=MLUtils.loadLibSVMFile(jsc.sc(),datapath).toJavaRDD();
//Splitthedataintotrainingandtestsets(30%heldoutfortesting)
[]splits=data.randomSplit(newdouble[]{0.7,0.3});
trainingData=splits[0];
testData=splits[1];
//Setparameters.
//EmptycategoricalFeaturesInfo
- 配套讲稿:
如PPT文件的首页显示word图标,表示该PPT已包含配套word讲稿。双击word图标可打开word文档。
- 特殊限制:
部分文档作品中含有的国旗、国徽等图片,仅作为作品整体效果示例展示,禁止商用。设计者仅对作品中独创性部分享有著作权。
- 关 键 词:
- sparkmllib 源码 阅读 分类 算法 DecisionTree