利用python中的matplotlib打印混淆矩阵实例
前面说过混淆矩阵是我们在处理分类问题时,很重要的指标,那么如何更好的把混淆矩阵给打印出来呢,直接做表或者是前端可视化,小编曾经就尝试过用前端(D5)做出来,然后截图,显得不那么好看。。
代码:
importitertools importmatplotlib.pyplotasplt importnumpyasnp defplot_confusion_matrix(cm,classes, normalize=False, title='Confusionmatrix', cmap=plt.cm.Blues): """ Thisfunctionprintsandplotstheconfusionmatrix. Normalizationcanbeappliedbysetting`normalize=True`. """ ifnormalize: cm=cm.astype('float')/cm.sum(axis=1)[:,np.newaxis] print("Normalizedconfusionmatrix") else: print('Confusionmatrix,withoutnormalization') print(cm) plt.imshow(cm,interpolation='nearest',cmap=cmap) plt.title(title) plt.colorbar() tick_marks=np.arange(len(classes)) plt.xticks(tick_marks,classes,rotation=45) plt.yticks(tick_marks,classes) fmt='.2f'ifnormalizeelse'd' thresh=cm.max()/2. fori,jinitertools.product(range(cm.shape[0]),range(cm.shape[1])): plt.text(j,i,format(cm[i,j],fmt), horizontalalignment="center", color="white"ifcm[i,j]>threshelse"black") plt.tight_layout() plt.ylabel('Truelabel') plt.xlabel('Predictedlabel') plt.show() #plt.savefig('confusion_matrix',dpi=200) cnf_matrix=np.array([ [4101,2,5,24,0], [50,3930,6,14,5], [29,3,3973,4,0], [45,7,1,3878,119], [31,1,8,28,3936], ]) class_names=['Buildings','Farmland','Greenbelt','Wasteland','Water'] #plt.figure() #plot_confusion_matrix(cnf_matrix,classes=class_names, #title='Confusionmatrix,withoutnormalization') #Plotnormalizedconfusionmatrix plt.figure() plot_confusion_matrix(cnf_matrix,classes=class_names,normalize=True, title='Normalizedconfusionmatrix')
在放矩阵位置,放一下你的混淆矩阵就可以,当然可视化混淆矩阵这一步也可以直接在模型运行中完成。
补充知识:混淆矩阵(Confusionmatrix)的原理及使用(scikit-learn和tensorflow)
原理
在机器学习中,混淆矩阵是一个误差矩阵,常用来可视化地评估监督学习算法的性能.混淆矩阵大小为(n_classes,n_classes)的方阵,其中n_classes表示类的数量.这个矩阵的每一行表示真实类中的实例,而每一列表示预测类中的实例(Tensorflow和scikit-learn采用的实现方式).也可以是,每一行表示预测类中的实例,而每一列表示真实类中的实例(ConfusionmatrixFromWikipedia中的定义).通过混淆矩阵,可以很容易看出系统是否会弄混两个类,这也是混淆矩阵名字的由来.
混淆矩阵是一种特殊类型的列联表(contingencytable)或交叉制表(crosstabulationorcrosstab).其有两维(真实值"actual"和预测值"predicted"),这两维都具有相同的类("classes")的集合.在列联表中,每个维度和类的组合是一个变量.列联表以表的形式,可视化地表示多个变量的频率分布.
使用混淆矩阵(scikit-learn和Tensorflow)
下面先介绍在scikit-learn和tensorflow中计算混淆矩阵的API(ApplicationProgrammingInterface)接口函数,然后在一个示例中,使用这两个API函数.
scikit-learn混淆矩阵函数sklearn.metrics.confusion_matrixAPI接口
skearn.metrics.confusion_matrix( y_true,#array,Goundtrue(correct)targetvalues y_pred,#array,Estimatedtargetsasreturnedbyaclassifier labels=None,#array,Listoflabelstoindexthematrix. sample_weight=None#array-likeofshape=[n_samples],Optionalsampleweights )
在scikit-learn中,计算混淆矩阵用来评估分类的准确度.
按照定义,混淆矩阵C中的元素Ci,j等于真实值为组i,而预测为组j的观测数(thenumberofobservations).所以对于二分类任务,预测结果中,正确的负例数(truenegatives,TN)为C0,0;错误的负例数(falsenegatives,FN)为C1,0;真实的正例数为C1,1;错误的正例数为C0,1.
如果labels为None,scikit-learn会把在出现在y_true或y_pred中的所有值添加到标记列表labels中,并排好序.
Tensorflow混淆矩阵函数tf.confusion_matrixAPI接口
tf.confusion_matrix( labels,#1-DTensorofreallabelsfortheclassificationtask predictions,#1-DTensorofpredictionsforagivenclassification num_classes=None,#Thepossiblenumberoflabelstheclassificationtaskcanhave dtype=tf.int32,#Datatypeoftheconfusionmatrix name=None,#Scopename weights=None,#AnoptionalTensorwhoseshapematchespredictions )
Tensorflowtf.confusion_matrix中的num_classes参数的含义,与scikit-learnsklearn.metrics.confusion_matrix中的labels参数相近,是与标记有关的参数,表示类的总个数,但没有列出具体的标记值.在Tensorflow中一般是以整数作为标记,如果标记为字符串等非整数类型,则需先转为整数表示.如果num_classes参数为None,则把labels和predictions中的最大值+1,作为num_classes参数值.
tf.confusion_matrix的weights参数和sklearn.metrics.confusion_matrix的sample_weight参数的含义相同,都是对预测值进行加权,在此基础上,计算混淆矩阵单元的值.
使用示例
#!/usr/bin/envpython #-*-coding:utf8-*- """ Author:klchang Description: Asimpleexamplefortf.confusion_matrixandsklearn.metrics.confusion_matrix. Date:2018.9.8 """ from__future__importprint_function importtensorflowastf importsklearn.metrics y_true=[1,2,4] y_pred=[2,2,4] #Buildgraphwithtf.confusion_matrixoperation sess=tf.InteractiveSession() op=tf.confusion_matrix(y_true,y_pred) op2=tf.confusion_matrix(y_true,y_pred,num_classes=6,dtype=tf.float32,weights=tf.constant([0.3,0.4,0.3])) #Executethegraph print("confusionmatrixintensorflow:") print("1.default:\n",op.eval()) print("2.customed:\n",sess.run(op2)) sess.close() #Usesklearn.metrics.confusion_matrixfunction print("\nconfusionmatrixinscikit-learn:") print("1.default:\n",sklearn.metrics.confusion_matrix(y_true,y_pred)) print("2.customed:\n",sklearn.metrics.confusion_matrix(y_true,y_pred,labels=range(6),sample_weight=[0.3,0.4,0.3]))
以上这篇利用python中的matplotlib打印混淆矩阵实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。