题目
识别手写数字
做法
开始做 Kaggle 的第一套题:识别手写数字。每个数字是一张 28×28 的灰度图像,展开后是一个 784 维的向量。这里朴素地跑了一个 KNN(k=3),距离用的是欧几里得距离,最终成绩为 0.96。
def knn(inX, num, k=3):
    """Classify one sample by majority vote among its k nearest training rows.

    Reads the module-level ``trainMat`` (one training sample per row) and
    ``labelList`` (the label of each training row), prints the prediction and
    appends the line ``"<num + 1>,<label>"`` to the module-level ``outFile``.

    Args:
        inX: Feature vector to classify; it must broadcast against the rows
            of ``trainMat``.
        num: Zero-based index of the sample; the output uses ``num + 1``.
        k: Number of nearest neighbours that vote. Defaults to 3, the value
            that used to be hard-coded.

    Returns:
        None. The result is reported through ``print`` and ``outFile``.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError('k must be at least 1')

    dataSet = trainMat
    labels = labelList

    # Broadcasting subtracts inX from every row without building a tiled copy.
    diffMat = dataSet - inX
    # Squared Euclidean distance orders neighbours exactly like the true
    # distance, so the square root is not needed for ranking.
    sqDistances = (diffMat ** 2).sum(axis=1)
    nearest = sqDistances.argsort()

    # Never ask for more neighbours than there are training rows.
    # (A conditional is used instead of min(): the file does
    # `from numpy import *`, which may replace some builtins.)
    dataSetSize = dataSet.shape[0]
    neighbourCount = k if k < dataSetSize else dataSetSize

    classCount = {}
    for idx in nearest[:neighbourCount]:
        vote = labels[idx]
        classCount[vote] = classCount.get(vote, 0) + 1

    # Strict '>' keeps the first-inserted label on ties, i.e. the label whose
    # closest member is nearest to inX.
    bestVotes = 0
    ans = ''
    for label, votes in classCount.items():
        if votes > bestVotes:
            ans = label
            bestVotes = votes

    # str() keeps the output working even if a label is not a string.
    print(str(num + 1) + ' = ' + str(ans))
    outFile.write(str(num + 1) + ',' + str(ans) + '\n')
    return
以后学到更多的知识再做优化。
在写法上,用上了 Python 的多线程(multiprocessing.dummy 提供的线程池)来并行处理,节省了一定的时间。
from multiprocessing.dummy import Pool
# outFile stays a module-level name because knn() writes each prediction to it.
# The with-block guarantees the file is closed even if a worker raises.
with open("out2.csv", 'w') as outFile:
    pool = Pool()
    try:
        # One knn(inX, num) call per test row; starmap blocks until all finish.
        pool.starmap(knn, zip(testMat, range(n)))
    finally:
        # Always release the worker threads, also on failure.
        pool.close()
        pool.join()
代码
from numpy import *
import csv
from multiprocessing.dummy import Pool
def knn_warp(args):
    """Unpack an argument sequence and forward it to ``knn``.

    Presumably meant for mappers that pass a single argument (for example
    ``Pool.map``); it is not called anywhere in the visible code.

    Args:
        args: Sequence of positional arguments for ``knn``.

    Returns:
        Whatever ``knn`` returns.
    """
    result = knn(*args)
    return result
def knn(inX, num, k=3):
    """Classify one sample by majority vote among its k nearest training rows.

    Reads the module-level ``trainMat`` (one training sample per row) and
    ``labelList`` (the label of each training row), prints the prediction and
    appends the line ``"<num + 1>,<label>"`` to the module-level ``outFile``.

    Args:
        inX: Feature vector to classify; it must broadcast against the rows
            of ``trainMat``.
        num: Zero-based index of the sample; the output uses ``num + 1``.
        k: Number of nearest neighbours that vote. Defaults to 3, the value
            that used to be hard-coded.

    Returns:
        None. The result is reported through ``print`` and ``outFile``.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError('k must be at least 1')

    dataSet = trainMat
    labels = labelList

    # Broadcasting subtracts inX from every row without building a tiled copy.
    diffMat = dataSet - inX
    # Squared Euclidean distance orders neighbours exactly like the true
    # distance, so the square root is not needed for ranking.
    sqDistances = (diffMat ** 2).sum(axis=1)
    nearest = sqDistances.argsort()

    # Never ask for more neighbours than there are training rows.
    # (A conditional is used instead of min(): the file does
    # `from numpy import *`, which may replace some builtins.)
    dataSetSize = dataSet.shape[0]
    neighbourCount = k if k < dataSetSize else dataSetSize

    classCount = {}
    for idx in nearest[:neighbourCount]:
        vote = labels[idx]
        classCount[vote] = classCount.get(vote, 0) + 1

    # Strict '>' keeps the first-inserted label on ties, i.e. the label whose
    # closest member is nearest to inX.
    bestVotes = 0
    ans = ''
    for label, votes in classCount.items():
        if votes > bestVotes:
            ans = label
            bestVotes = votes

    # str() keeps the output working even if a label is not a string.
    print(str(num + 1) + ' = ' + str(ans))
    outFile.write(str(num + 1) + ',' + str(ans) + '\n')
    return
def readTrain(row, i):
    """Store one parsed training CSV row in the module-level training arrays.

    Writes the row's ``'label'`` value into ``labelList[i]`` and its pixel
    columns (``'pixel0'``, ``'pixel1'``, ...) into row ``i`` of ``trainMat``,
    then prints ``i`` as a progress marker.

    Args:
        row: Mapping from column name to string value, as produced by
            ``csv.DictReader``.
        i: Zero-based index of the row within the training set.

    Raises:
        KeyError: If the label or a required pixel column is missing.
        ValueError: If a pixel value is not an integer literal.
    """
    labelList[i] = row['label']
    # Take the pixel count from the matrix width instead of hard-coding 784,
    # so the loader follows whatever shape the caller allocated.
    pixelCount = trainMat.shape[1]
    # One row assignment replaces one element-wise store per pixel.
    trainMat[i, :] = [int(row['pixel' + str(x)]) for x in range(pixelCount)]
    print(str(i))
def readTest(row, i):
    """Store one parsed test CSV row in the module-level ``testMat``.

    Copies the pixel columns (``'pixel0'``, ``'pixel1'``, ...) of ``row`` into
    row ``i`` of ``testMat``, then prints ``i`` as a progress marker.

    Args:
        row: Mapping from column name to string value, as produced by
            ``csv.DictReader``.
        i: Zero-based index of the row within the test set.

    Raises:
        KeyError: If a required pixel column is missing.
        ValueError: If a pixel value is not an integer literal.
    """
    # Take the pixel count from the matrix width instead of hard-coding 784,
    # so the loader follows whatever shape the caller allocated.
    pixelCount = testMat.shape[1]
    # One row assignment replaces one element-wise store per pixel.
    testMat[i, :] = [int(row['pixel' + str(x)]) for x in range(pixelCount)]
    print(str(i))
# NOTE: a `global` statement at module scope has no effect; names assigned at
# module level are already global. These lines only signal that labelList,
# trainMat and outFile are shared module-level state, assigned in the
# __main__ block and used by knn() and readTrain().
global labelList
global trainMat
global outFile
if __name__ == '__main__':
    # ---- Load the training set -------------------------------------------
    # Read the file once. DictReader skips blank lines, so len(trainRows) is
    # the number of data rows actually parsed. Counting raw lines could
    # over-allocate and leave placeholder rows and labels behind.
    with open('train.csv', newline='') as f:
        trainRows = list(csv.DictReader(f))
    m = len(trainRows)
    labelList = list(range(m))      # placeholders, overwritten by readTrain
    trainMat = zeros((m, 784))      # 784 pixel columns per image
    pool = Pool()
    try:
        pool.starmap(readTrain, zip(trainRows, range(m)))
    finally:
        # Always release the worker threads, also on failure.
        pool.close()
        pool.join()
    del trainRows                   # parsed strings are no longer needed

    # ---- Load the test set -----------------------------------------------
    with open('test.csv', newline='') as f:
        testRows = list(csv.DictReader(f))
    n = len(testRows)
    testMat = zeros((n, 784))
    pool = Pool()
    try:
        pool.starmap(readTest, zip(testRows, range(n)))
    finally:
        pool.close()
        pool.join()
    del testRows

    # ---- Classify every test row -----------------------------------------
    # outFile must be a module-level name because knn() writes to it; the
    # with-block closes it even if a worker raises.
    with open("out2.csv", 'w') as outFile:
        pool = Pool()
        try:
            pool.starmap(knn, zip(testMat, range(n)))
        finally:
            pool.close()
            pool.join()
本文介绍使用K近邻算法(KNN)进行手写数字识别的过程,通过Python实现并对Kaggle的手写数字数据集进行了初步尝试,最终取得了0.96的成绩。
Digit Recognizer
744

被折叠的 条评论
为什么被折叠?



