您的位置:首页 > 编程语言 > Python开发

tensorflow 学习专栏(五):在mnist数据集上使用tensorflow实现临近算法(Nearest-Neighbor)进行手写数字识别

2018-02-22 10:54 1671 查看
实现最临近算法的具体步骤如下:
1.为了判断未知实列的类别,选取已知类别的实例作为参考(如图所示:Xu点为未知类别的点);
2.选择参数K(在本实验中K=所有已知实列点的数目,即将所有已知种类的数据用作参考数据);
3.计算未知实例点与所有已知实例点的距离distance;
4.将到该未知点距离最近的点的类别作为该点的类别;



import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('./mnist',one_hot=True) #导入mnist数据集
Xtr,Ytr = mnist.train.next_batch(5000) #5000个样本作为已知类别的样本
Xte,Yte = mnist.test.next_batch(200) #200个未知类别的样本

xtr = tf.placeholder(tf.float32,[None,784]) #设置占位符放置5000个样本[5000,784]
xte = tf.placeholder(tf.float32,[784]) #每次取一个未知样本

distance = tf.reduce_sum(tf.abs(tf.add(xtr,tf.negative(xte))),axis=1) #计算未知样本到每一个(5000)已知样本的距离
pred = tf.arg_min(distance,0) #返回:到未知样本距离最短的已知样本
accuracy = 0

sess= tf.Session() #初始化
sess.run(tf.global_variables_initializer())

for i in range(len(Xte)):
prediction_ = sess.run(pred,feed_dict={xtr:Xtr,xte:Xte[i]})
print('Test:',i,'Prediction:',np.argmax(Ytr[prediction_]),\ #显示预测类别与真实类别
'True Class:',np.argmax(Yte[i]))
if np.argmax(Ytr[prediction_]) == np.argmax(Yte[i]):
accuracy += 1./len(Xte) #显示准确率
print('Accuracy:',accuracy)运行结果如下:

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
相关文章推荐