学习进度笔记

学习进度笔记09

TensorFlow K近邻算法

import numpy as np  

import tensorflow as tf  

from tensorflow.examples.tutorials.mnist import input_data  

import os  

os.environ["CUDA_VISIBLE_DEVICES"]="0"  

mnist =input_data.read_data_sets("/home/yxcx/tf_data/MNIST_data",one_hot=True)  

Xtr,Ytr=mnist.train.next_batch(5000)  

Xte,Yte=mnist.test.next_batch(200)  

 

#tf Graph Input  

xtr=tf.placeholder("float",[None,784])  

xte=tf.placeholder("float",[784])  

distance =tf.reduce_sum(tf.abs(tf.add(xtr,tf.negative(xte))),reduction_indices=1)  

pred=tf.argmin(distance,0)  

  

accuracy=0  

init=tf.global_variables_initializer()  

 

 

#Start training  

with tf.Session() as sess:  

    sess.run(init)  

    for i in range(len(Xte)):  

        #Get nearest nerighbor  

        nn_index=sess.run(pred,feed_dict={xtr:Xtr,xte:Xte[i,:]})  

        print("Test",i ,"Prediction:",np.argmax(Ytr[nn_index]),"True Class:",np.argmax(Yte[i]))  

        if np.argmax(Ytr[nn_index])==np.argmax(Yte[i]):  

            accuracy+=1./len(Xte)  

    print("Done!")  

    print("accuacy:" ,accuracy)  

posted @ 2021-01-18 06:59  城南漠北  阅读(26)  评论(0编辑  收藏  举报