利用python中的matplotlib打印混淆矩阵实例
前面说过混淆矩阵是我们在处理分类问题时,很重要的指标,那么如何更好的把混淆矩阵给打印出来呢,直接做表或者是前端可视化,小编曾经就尝试过用前端(D5)做出来,然后截图,显得不那么好看。。
代码:
importitertools
importmatplotlib.pyplotasplt
importnumpyasnp
defplot_confusion_matrix(cm,classes,
normalize=False,
title='Confusionmatrix',
cmap=plt.cm.Blues):
"""
Thisfunctionprintsandplotstheconfusionmatrix.
Normalizationcanbeappliedbysetting`normalize=True`.
"""
ifnormalize:
cm=cm.astype('float')/cm.sum(axis=1)[:,np.newaxis]
print("Normalizedconfusionmatrix")
else:
print('Confusionmatrix,withoutnormalization')
print(cm)
plt.imshow(cm,interpolation='nearest',cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks=np.arange(len(classes))
plt.xticks(tick_marks,classes,rotation=45)
plt.yticks(tick_marks,classes)
fmt='.2f'ifnormalizeelse'd'
thresh=cm.max()/2.
fori,jinitertools.product(range(cm.shape[0]),range(cm.shape[1])):
plt.text(j,i,format(cm[i,j],fmt),
horizontalalignment="center",
color="white"ifcm[i,j]>threshelse"black")
plt.tight_layout()
plt.ylabel('Truelabel')
plt.xlabel('Predictedlabel')
plt.show()
#plt.savefig('confusion_matrix',dpi=200)
cnf_matrix=np.array([
[4101,2,5,24,0],
[50,3930,6,14,5],
[29,3,3973,4,0],
[45,7,1,3878,119],
[31,1,8,28,3936],
])
class_names=['Buildings','Farmland','Greenbelt','Wasteland','Water']
#plt.figure()
#plot_confusion_matrix(cnf_matrix,classes=class_names,
#title='Confusionmatrix,withoutnormalization')
#Plotnormalizedconfusionmatrix
plt.figure()
plot_confusion_matrix(cnf_matrix,classes=class_names,normalize=True,
title='Normalizedconfusionmatrix')
在放矩阵位置,放一下你的混淆矩阵就可以,当然可视化混淆矩阵这一步也可以直接在模型运行中完成。
补充知识:混淆矩阵(Confusionmatrix)的原理及使用(scikit-learn和tensorflow)
原理
在机器学习中,混淆矩阵是一个误差矩阵,常用来可视化地评估监督学习算法的性能.混淆矩阵大小为(n_classes,n_classes)的方阵,其中n_classes表示类的数量.这个矩阵的每一行表示真实类中的实例,而每一列表示预测类中的实例(Tensorflow和scikit-learn采用的实现方式).也可以是,每一行表示预测类中的实例,而每一列表示真实类中的实例(ConfusionmatrixFromWikipedia中的定义).通过混淆矩阵,可以很容易看出系统是否会弄混两个类,这也是混淆矩阵名字的由来.
混淆矩阵是一种特殊类型的列联表(contingencytable)或交叉制表(crosstabulationorcrosstab).其有两维(真实值"actual"和预测值"predicted"),这两维都具有相同的类("classes")的集合.在列联表中,每个维度和类的组合是一个变量.列联表以表的形式,可视化地表示多个变量的频率分布.
使用混淆矩阵(scikit-learn和Tensorflow)
下面先介绍在scikit-learn和tensorflow中计算混淆矩阵的API(ApplicationProgrammingInterface)接口函数,然后在一个示例中,使用这两个API函数.
scikit-learn混淆矩阵函数sklearn.metrics.confusion_matrixAPI接口
skearn.metrics.confusion_matrix( y_true,#array,Goundtrue(correct)targetvalues y_pred,#array,Estimatedtargetsasreturnedbyaclassifier labels=None,#array,Listoflabelstoindexthematrix. sample_weight=None#array-likeofshape=[n_samples],Optionalsampleweights )
在scikit-learn中,计算混淆矩阵用来评估分类的准确度.
按照定义,混淆矩阵C中的元素Ci,j等于真实值为组i,而预测为组j的观测数(thenumberofobservations).所以对于二分类任务,预测结果中,正确的负例数(truenegatives,TN)为C0,0;错误的负例数(falsenegatives,FN)为C1,0;真实的正例数为C1,1;错误的正例数为C0,1.
如果labels为None,scikit-learn会把在出现在y_true或y_pred中的所有值添加到标记列表labels中,并排好序.
Tensorflow混淆矩阵函数tf.confusion_matrixAPI接口
tf.confusion_matrix( labels,#1-DTensorofreallabelsfortheclassificationtask predictions,#1-DTensorofpredictionsforagivenclassification num_classes=None,#Thepossiblenumberoflabelstheclassificationtaskcanhave dtype=tf.int32,#Datatypeoftheconfusionmatrix name=None,#Scopename weights=None,#AnoptionalTensorwhoseshapematchespredictions )
Tensorflowtf.confusion_matrix中的num_classes参数的含义,与scikit-learnsklearn.metrics.confusion_matrix中的labels参数相近,是与标记有关的参数,表示类的总个数,但没有列出具体的标记值.在Tensorflow中一般是以整数作为标记,如果标记为字符串等非整数类型,则需先转为整数表示.如果num_classes参数为None,则把labels和predictions中的最大值+1,作为num_classes参数值.
tf.confusion_matrix的weights参数和sklearn.metrics.confusion_matrix的sample_weight参数的含义相同,都是对预测值进行加权,在此基础上,计算混淆矩阵单元的值.
使用示例
#!/usr/bin/envpython
#-*-coding:utf8-*-
"""
Author:klchang
Description:
Asimpleexamplefortf.confusion_matrixandsklearn.metrics.confusion_matrix.
Date:2018.9.8
"""
from__future__importprint_function
importtensorflowastf
importsklearn.metrics
y_true=[1,2,4]
y_pred=[2,2,4]
#Buildgraphwithtf.confusion_matrixoperation
sess=tf.InteractiveSession()
op=tf.confusion_matrix(y_true,y_pred)
op2=tf.confusion_matrix(y_true,y_pred,num_classes=6,dtype=tf.float32,weights=tf.constant([0.3,0.4,0.3]))
#Executethegraph
print("confusionmatrixintensorflow:")
print("1.default:\n",op.eval())
print("2.customed:\n",sess.run(op2))
sess.close()
#Usesklearn.metrics.confusion_matrixfunction
print("\nconfusionmatrixinscikit-learn:")
print("1.default:\n",sklearn.metrics.confusion_matrix(y_true,y_pred))
print("2.customed:\n",sklearn.metrics.confusion_matrix(y_true,y_pred,labels=range(6),sample_weight=[0.3,0.4,0.3]))
以上这篇利用python中的matplotlib打印混淆矩阵实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。