python使用KNN算法识别手写数字
本文实例为大家分享了python使用KNN算法识别手写数字的具体代码,供大家参考,具体内容如下
#-*-coding:utf-8-*- #pipinstallnumpy importos importos.path fromnumpyimport* importoperator importtime fromosimportlistdir """ 描述: KNN算法实现分类器 参数: inputPoint:测试集 dataSet:训练集 labels:类别标签 k:K个邻居 返回值: 该测试数据的类别 """ defclassify(inputPoint,dataSet,labels,k): dataSetSize=dataSet.shape[0]#已知分类的数据集(训练集)的行数 #先tile函数将输入点拓展成与训练集相同维数的矩阵,再计算欧氏距离 diffMat=tile(inputPoint,(dataSetSize,1))-dataSet#样本与训练集的差值矩阵 #print(inputPoint); sqDiffMat=diffMat**2#sqDiffMat的数据类型是nump提供的ndarray,这不是矩阵的平方,而是每个元素变成原来的平方。 sqDistances=sqDiffMat.sum(axis=1)#计算每一行上元素的和 #print(sqDistances); distances=sqDistances**0.5#开方得到欧拉距离矩阵 #print(distances); sortedDistIndicies=distances.argsort()#按distances中元素进行升序排序后得到的对应下标的列表,argsort函数返回的是数组值从小到大的索引值 #print(sortedDistIndicies); #classCount数据类型是这样的{0:2,1:2},字典key:value classCount={} #选择距离最小的k个点 foriinrange(k): voteIlabel=labels[sortedDistIndicies[i]] #print(voteIlabel) #类别数加1 classCount[voteIlabel]=classCount.get(voteIlabel,0)+1 print(classCount)#{1:1,7:2} #按classCount字典的第2个元素(即类别出现的次数)从大到小排序 sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True) print(sortedClassCount)#[(7,2),(1,1)] returnsortedClassCount[0][0] """ 描述: 读取指定文件名的文本数据,构建一个矩阵 参数: 文本文件名称 返回值: 一个单行矩阵 """ defimg2vector(filename): returnVect=[] fr=open(filename) foriinrange(32): lineStr=fr.readline() forjinrange(32): returnVect.append(int(lineStr[j])) returnreturnVect """ 描述: 从文件名中解析分类数字,比如由0_0.txt得知这个文本代表的数字分类是0 参数: 文本文件名称 返回值: 一个代表分类的数字 """ defclassnumCut(fileName): fileStr=fileName.split('.')[0] classNumStr=int(fileStr.split('_')[0]) returnclassNumStr """ 描述: 构建训练集数据向量,及对应分类标签向量 参数: 无 返回值: hwLabels:分类标签矩阵 trainingMat:训练数据集矩阵 """ deftrainingDataSet(): hwLabels=[] trainingFileList=listdir('trainingDigits')#获取目录内容 m=len(trainingFileList) #zeros返回全部是0的矩阵,参数是行和列 trainingMat=zeros((m,1024))#m维向量的训练集 foriinrange(m): #print(i); fileNameStr=trainingFileList[i] hwLabels.append(classnumCut(fileNameStr)) trainingMat[i,:]=img2vector('trainingDigits/%s'%fileNameStr) returnhwLabels,trainingMat """ 描述: 主函数,最终打印识别了多少个数字以及识别的错误率 参数: 无 返回值: 无 """ defhandwritingTest(): """ hwLabels,trainingMat是标签和训练数据, hwLabels是一个一维矩阵,代表每个文本对应的标签(即文本所代表的数字类型) trainingMat是一个多维矩阵,每一行都代表一个文本的数据,每行有1024个数字(0或1) """ hwLabels,trainingMat=trainingDataSet()#构建训练集 testFileList=listdir('testDigits')#获取测试集 errorCount=0.0#错误数 mTest=len(testFileList)#测试集总样本数 t1=time.time() foriinrange(mTest): fileNameStr=testFileList[i] classNumStr=classnumCut(fileNameStr) #img2vector返回一个文本对应的一维矩阵,1024个0或者1 vectorUnderTest=img2vector('testDigits/%s'%fileNameStr) #调用knn算法进行测试 classifierResult=classify(vectorUnderTest,trainingMat,hwLabels,3) #打印测试出来的结果和真正的结果,看看是否匹配 print("theclassifiercamebackwith:%d,therealansweris:%d"%(classifierResult,classNumStr)) #如果测试出来的值和原值不相等,errorCount+1 if(classifierResult!=classNumStr): errorCount+=1.0 print("\nthetotalnumberoftestsis:%d"%mTest)#输出测试总样本数 print("thetotalnumberoferrorsis:%d"%errorCount)#输出测试错误样本数 print("thetotalerrorrateis:%f"%(errorCount/float(mTest)))#输出错误率 t2=time.time() print("Costtime:%.2fmin,%.4fs."%((t2-t1)//60,(t2-t1)%60))#测试耗时 """ 描述: 指定handwritingTest()为主函数 """ if__name__=="__main__": handwritingTest()
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。