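Implementing a Random Forest with Spark
This article shares working code for training a random forest classifier with Spark MLlib, for reference. The details are as follows: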
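import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import junit.framework.TestCase;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.rdd.RDD;

import scala.reflect.ClassTag;

public class RandomForestClassficationTest extends TestCase implements Serializable
{
    private static final long serialVersionUID = 7802523720751354318L;

    // Holds the true label and the predicted label of one test point.
    class PredictResult implements Serializable {
        private static final long serialVersionUID = -168308887976477219L;
        double label;
        double prediction;

        public PredictResult(double label, double prediction) {
            this.label = label;
            this.prediction = prediction;
        }

        @Override
        public String toString() {
            return this.label + " : " + this.prediction;
        }
    }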
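    public void test_randomForest() {
        // Run Spark locally inside the test JVM.
        SparkConf sparkConf = new SparkConf();
        sparkConf.setAppName("RandomForest");
        sparkConf.setMaster("local");
        SparkContext sc = new SparkContext(sparkConf);
        // Load the LIBSVM sample file from the root of the test classpath.
        String dataPath = RandomForestClassficationTest.class.getResource("/").getPath() + "/sample_libsvm_data.txt";
        RDD<LabeledPoint> dataSet = MLUtils.loadLibSVMFile(sc, dataPath);
        // Split into 70% training data and 30% test data, with seed 1.
        RDD<LabeledPoint>[] rddList = dataSet.randomSplit(new double[]{0.7, 0.3}, 1);
        RDD<LabeledPoint> trainingData = rddList[0];
        RDD<LabeledPoint> testData = rddList[1];
        ClassTag<LabeledPoint> labelPointClassTag = trainingData.elementClassTag();
        JavaRDD<LabeledPoint> trainingJavaData = new JavaRDD<LabeledPoint>(trainingData, labelPointClassTag);
        int numClasses = 2;
        // An empty map means every feature is treated as continuous.
        Map<Integer, Integer> categoricalFeatureInfos = new HashMap<Integer, Integer>();
        int numTrees = 3;
        String featureSubsetStrategy = "auto";
        String impurity = "gini";
        int maxDepth = 4;
        int maxBins = 32;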
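        /**
         * 1. numClasses: the number of classes, 2 here
         * 2. numTrees: the number of trees in the random forest
         * 3. featureSubsetStrategy: how many features are considered at each split; "auto" lets Spark choose
         */
        final RandomForestModel model = RandomForest.trainClassifier(trainingJavaData,
                numClasses,
                categoricalFeatureInfos,
                numTrees,
                featureSubsetStrategy,
                impurity,
                maxDepth,
                maxBins,
                1); // random seed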
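        JavaRDD<LabeledPoint> testJavaData = new JavaRDD<LabeledPoint>(testData, testData.elementClassTag());
        // Predict every test point and pair the prediction with the true label.
        JavaRDD<PredictResult> predictRddResult = testJavaData.map(new Function<LabeledPoint, PredictResult>() {
            private static final long serialVersionUID = 1L;

            public PredictResult call(LabeledPoint point) throws Exception {
                double pointLabel = point.label();
                double prediction = model.predict(point.features());
                return new PredictResult(pointLabel, prediction);
            }
        });
        // Bring the results back to the driver and print "label : prediction".
        List<PredictResult> predictResultList = predictRddResult.collect();
        for (PredictResult result : predictResultList) {
            System.out.println(result.toString());
        }
        // Print the learned trees in a readable form.
        System.out.println(model.toDebugString());
        // Release the local Spark context once the test is done.
        sc.stop();
    }
}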
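The test above only prints each label next to its prediction. To turn such pairs into a single number, one option is to compute the accuracy on the driver after collect(). The following is a minimal plain-Java sketch, not part of the original code: the class name AccuracySketch, the accuracy helper, and the sample pairs are made up for illustration, and each pair is assumed to be {label, prediction}.

```java
import java.util.Arrays;
import java.util.List;

public class AccuracySketch {

    // Fraction of {label, prediction} pairs in which the prediction equals the label.
    static double accuracy(List<double[]> labelAndPrediction) {
        if (labelAndPrediction.isEmpty()) {
            throw new IllegalArgumentException("no predictions to score");
        }
        int correct = 0;
        for (double[] pair : labelAndPrediction) {
            // Labels here are exact class values (0.0 or 1.0), so == is safe.
            if (pair[0] == pair[1]) {
                correct++;
            }
        }
        return (double) correct / labelAndPrediction.size();
    }

    public static void main(String[] args) {
        // Made-up pairs for illustration; these are not results from the article.
        List<double[]> pairs = Arrays.asList(
                new double[]{1.0, 1.0},
                new double[]{0.0, 0.0},
                new double[]{1.0, 0.0},
                new double[]{0.0, 0.0});
        System.out.println(accuracy(pairs)); // prints 0.75
    }
}
```

In the test itself, the same loop could run over predictResultList, comparing result.label with result.prediction.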
The resulting random forest is printed as follows:
TreeEnsembleModel classifier with 3 trees

  Tree 0:
    If (feature 435 <= 0.0)
     If (feature 516 <= 0.0)
      Predict: 0.0
     Else (feature 516 > 0.0)
      Predict: 1.0
    Else (feature 435 > 0.0)
     Predict: 1.0
  Tree 1:
    If (feature 512 <= 0.0)
     Predict: 1.0
    Else (feature 512 > 0.0)
     Predict: 0.0
  Tree 2:
    If (feature 377 <= 1.0)
     Predict: 0.0
    Else (feature 377 > 1.0)
     If (feature 455 <= 0.0)
      Predict: 1.0
     Else (feature 455 > 0.0)
      Predict: 0.0
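To see how a forest like this produces one prediction, the three printed trees can be written out as plain if/else rules and combined by majority vote, which is how a classification forest typically aggregates its trees. The sketch below is illustrative only: the class ForestVoteSketch and its helpers are hypothetical, the thresholds are copied from the output above, and features missing from the map are assumed to be 0.0, as in a sparse LIBSVM row.

```java
import java.util.HashMap;
import java.util.Map;

public class ForestVoteSketch {

    // Look up a feature by index; absent indices count as 0.0 (sparse input).
    static double feature(Map<Integer, Double> x, int index) {
        Double v = x.get(index);
        return v == null ? 0.0 : v;
    }

    // Tree 0 from the printed model.
    static double tree0(Map<Integer, Double> x) {
        if (feature(x, 435) <= 0.0) {
            return feature(x, 516) <= 0.0 ? 0.0 : 1.0;
        }
        return 1.0;
    }

    // Tree 1 from the printed model.
    static double tree1(Map<Integer, Double> x) {
        return feature(x, 512) <= 0.0 ? 1.0 : 0.0;
    }

    // Tree 2 from the printed model.
    static double tree2(Map<Integer, Double> x) {
        if (feature(x, 377) <= 1.0) {
            return 0.0;
        }
        return feature(x, 455) <= 0.0 ? 1.0 : 0.0;
    }

    // Majority vote: with three trees and two classes, two votes decide the class.
    static double predict(Map<Integer, Double> x) {
        double votesForOne = tree0(x) + tree1(x) + tree2(x);
        return votesForOne >= 2.0 ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        // Hypothetical input: only feature 435 is non-zero.
        Map<Integer, Double> x = new HashMap<>();
        x.put(435, 12.0);
        // Tree 0 votes 1.0, tree 1 votes 1.0, tree 2 votes 0.0.
        System.out.println(predict(x)); // prints 1.0
    }
}
```

With an all-zero input, only tree 1 votes for class 1.0, so the vote returns 0.0.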