Implementing a Random Forest with Spark
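This article shares working code for training a random forest classifier with Spark (the RDD-based MLlib API), for reference. The details follow.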
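import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import junit.framework.TestCase;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.rdd.RDD;

import scala.reflect.ClassTag;

public class RandomForestClassficationTest extends TestCase implements Serializable {

    private static final long serialVersionUID = 7802523720751354318L;

    // Holds the true label and the predicted label of one test point.
    class PredictResult implements Serializable {

        private static final long serialVersionUID = -168308887976477219L;

        double label;
        double prediction;

        public PredictResult(double label, double prediction) {
            this.label = label;
            this.prediction = prediction;
        }

        @Override
        public String toString() {
            return this.label + ":" + this.prediction;
        }
    }

    public void test_randomForest() {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setAppName("RandomForest");
        sparkConf.setMaster("local");
        SparkContext sc = new SparkContext(sparkConf);

        // Load the LIBSVM-format sample file from the classpath root.
        String dataPath = RandomForestClassficationTest.class.getResource("/").getPath()
                + "/sample_libsvm_data.txt";
        RDD<LabeledPoint> dataSet = MLUtils.loadLibSVMFile(sc, dataPath);

        // Split into 70% training data and 30% test data, with seed 1.
        RDD<LabeledPoint>[] rddList = dataSet.randomSplit(new double[] {0.7, 0.3}, 1);
        RDD<LabeledPoint> trainingData = rddList[0];
        RDD<LabeledPoint> testData = rddList[1];

        // Wrap the Scala RDD in a JavaRDD, as required by the Java training API.
        ClassTag<LabeledPoint> labelPointClassTag = trainingData.elementClassTag();
        JavaRDD<LabeledPoint> trainingJavaData =
                new JavaRDD<LabeledPoint>(trainingData, labelPointClassTag);

        int numClasses = 2;
        // An empty map means every feature is treated as continuous.
        Map<Integer, Integer> categoricalFeatureInfos = new HashMap<Integer, Integer>();
        int numTrees = 3;
        String featureSubsetStrategy = "auto";
        String impurity = "gini";
        int maxDepth = 4;
        int maxBins = 32;

        /*
         * 1. numClasses: the number of classes, 2 here.
         * 2. numTrees: the number of trees in the random forest.
         * 3. featureSubsetStrategy: how many features are considered at each split;
         *    "auto" lets Spark choose.
         * The last argument (1) is the random seed.
         */
        final RandomForestModel model = RandomForest.trainClassifier(trainingJavaData,
                numClasses,
                categoricalFeatureInfos,
                numTrees,
                featureSubsetStrategy,
                impurity,
                maxDepth,
                maxBins,
                1);

        // Predict a label for every test point and pair it with the true label.
        JavaRDD<LabeledPoint> testJavaData =
                new JavaRDD<LabeledPoint>(testData, testData.elementClassTag());
        JavaRDD<PredictResult> predictRddResult =
                testJavaData.map(new Function<LabeledPoint, PredictResult>() {

            private static final long serialVersionUID = 1L;

            @Override
            public PredictResult call(LabeledPoint point) throws Exception {
                double pointLabel = point.label();
                double prediction = model.predict(point.features());
                PredictResult result = new PredictResult(pointLabel, prediction);
                return result;
            }
        });

        // Print each label:prediction pair, then the structure of the trained forest.
        List<PredictResult> predictResultList = predictRddResult.collect();
        for (PredictResult result : predictResultList) {
            System.out.println(result.toString());
        }
        System.out.println(model.toDebugString());
    }
}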
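The test only prints each label:prediction pair. To condense those pairs into a single number, one option is the fraction of test points whose prediction differs from the label. The following is a minimal plain-Java sketch, not part of the original code: `ErrorRateSketch`, `testError`, and the sample pairs are hypothetical names and made-up data used only for illustration. In the Spark program the pairs would come from `predictResultList`.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper for illustration; not part of the original test class.
class ErrorRateSketch {

    /**
     * Returns the fraction of pairs whose prediction differs from the label.
     * Each element is a two-element array {label, prediction}.
     */
    static double testError(List<double[]> labelAndPrediction) {
        if (labelAndPrediction.isEmpty()) {
            throw new IllegalArgumentException("no predictions to evaluate");
        }
        int wrong = 0;
        for (double[] pair : labelAndPrediction) {
            // Class labels are exact values such as 0.0 and 1.0, so == / != is safe here.
            if (pair[0] != pair[1]) {
                wrong++;
            }
        }
        return (double) wrong / labelAndPrediction.size();
    }

    public static void main(String[] args) {
        // Made-up {label, prediction} pairs, for illustration only.
        List<double[]> pairs = Arrays.asList(
                new double[] {0.0, 0.0},
                new double[] {1.0, 1.0},
                new double[] {1.0, 0.0},
                new double[] {0.0, 0.0});
        System.out.println("Test error = " + testError(pairs)); // prints "Test error = 0.25"
    }
}
```

With only a few trees and a single 70/30 split, such a number depends heavily on the seed, so it is best read as a rough sanity check.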
The trained random forest is printed as follows:
TreeEnsembleModel classifier with 3 trees

  Tree 0:
    If (feature 435 <= 0.0)
     If (feature 516 <= 0.0)
      Predict: 0.0
     Else (feature 516 > 0.0)
      Predict: 1.0
    Else (feature 435 > 0.0)
     Predict: 1.0
  Tree 1:
    If (feature 512 <= 0.0)
     Predict: 1.0
    Else (feature 512 > 0.0)
     Predict: 0.0
  Tree 2:
    If (feature 377 <= 1.0)
     Predict: 0.0
    Else (feature 377 > 1.0)
     If (feature 455 <= 0.0)
      Predict: 1.0
     Else (feature 455 > 0.0)
      Predict: 0.0
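To read this dump: each tree tests one feature per node and follows the `If` or `Else` branch until it reaches a `Predict` leaf. For classification, MLlib's random forest model then combines the trees by majority vote. The sketch below hard-codes the three printed trees in plain Java to make that concrete. It is only an illustration: the class name `ForestVoteSketch`, its method names, and the sample feature vector are hypothetical, and the real model is evaluated through `model.predict(...)`.

```java
// Hypothetical hand-written copy of the three printed trees; not generated by Spark.
class ForestVoteSketch {

    // Tree 0 from the debug string.
    static double tree0(double[] f) {
        if (f[435] <= 0.0) {
            return (f[516] <= 0.0) ? 0.0 : 1.0;
        }
        return 1.0;
    }

    // Tree 1 from the debug string.
    static double tree1(double[] f) {
        return (f[512] <= 0.0) ? 1.0 : 0.0;
    }

    // Tree 2 from the debug string.
    static double tree2(double[] f) {
        if (f[377] <= 1.0) {
            return 0.0;
        }
        return (f[455] <= 0.0) ? 1.0 : 0.0;
    }

    /** Majority vote over the three binary trees: class 1.0 wins with at least two votes. */
    static double predict(double[] f) {
        double votesForOne = tree0(f) + tree1(f) + tree2(f);
        return votesForOne >= 2.0 ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        // Made-up feature vector, just long enough for the largest index used (516).
        double[] features = new double[517];
        System.out.println(predict(features)); // prints 0.0 (only tree 1 votes for class 1)

        features[435] = 1.0;
        System.out.println(predict(features)); // prints 1.0 (trees 0 and 1 vote for class 1)
    }
}
```

With three trees and two classes a tie cannot occur, which is one reason to pick an odd `numTrees` for binary problems.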