BankNote

 1 # coding=utf-8
 2 import pandas as pd
 3 import numpy as np
 4 from sklearn import cross_validation
 5 import tensorflow as tf
 6 
 7 global flag
 8 flag=0
 9 
def DataPreprocessing(csv_path="ceshi.csv", n_features=4, test_size=0.25, random_state=0):
    """Load the data CSV and split it into stratified train/test arrays.

    Args:
        csv_path: Path of the comma-separated data file. The first row is
            read as the header.
        n_features: Number of leading columns used as features; every
            remaining column is used as the label.
        test_size: Fraction of the rows placed in the test split.
        random_state: Seed handed to the splitter so the split is repeatable.

    Returns:
        The result of ``train_test_split``, unpacked by callers as
        ``X_train, X_test, Y_train, Y_test``.
    """
    data = pd.read_csv(csv_path, sep=',', header=0, keep_default_na=True, na_values=[])
    features = np.array(data.iloc[:, :n_features])
    labels = np.array(data.iloc[:, n_features:])

    # NOTE(review): `sklearn.cross_validation` is deprecated and absent from
    # scikit-learn >= 0.20; `sklearn.model_selection.train_test_split` is the
    # replacement. Switching requires changing the file-level import as well.
    # Stratifying on the labels keeps the class ratio equal in both splits.
    return cross_validation.train_test_split(
        features,
        labels,
        test_size=test_size,
        random_state=random_state,
        stratify=labels,
    )
36 
37 
def GetInputs():
    """Return ``(features, labels)`` constant tensors for the estimator.

    Takes no arguments so it can be passed directly as ``input_fn``. The split
    is chosen through the module-level ``flag``: ``0`` yields the training
    data, any other value yields the test data.

    Returns:
        A 2-tuple ``(x, y)`` of ``tf.constant`` tensors built from the
        selected split produced by ``DataPreprocessing()``.
    """
    X_train, X_test, Y_train, Y_test = DataPreprocessing()

    if flag == 0:
        features, labels = X_train, Y_train
    else:
        features, labels = X_test, Y_test

    # Only the split that is actually returned is converted, so no unused
    # tensors are created on each call.
    return tf.constant(features), tf.constant(labels)
68 
69 
def Main(model_dir="/home/jiangjing/TensorflowModel/banknote", steps=2000):
    """Train a DNN classifier on the training split and print test accuracy.

    Args:
        model_dir: Directory handed to the estimator as ``model_dir``.
        steps: Number of training steps passed to ``fit``.

    Side effects:
        Sets the module-level ``flag`` (0 while training, 1 for evaluation;
        it is left at 1 on return) and prints the accuracy to stdout.
    """
    global flag

    feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

    clf = tf.contrib.learn.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[10, 20, 10],
        n_classes=2,
        model_dir=model_dir,
    )

    # Select the training split explicitly so a repeated call to Main() does
    # not train on the test split left selected by a previous run.
    flag = 0
    clf.fit(input_fn=GetInputs, steps=steps)

    # Switch GetInputs over to the test split for evaluation.
    flag = 1
    accuracy_score = clf.evaluate(input_fn=GetInputs, steps=1)["accuracy"]

    print("\nTest Accuracy:{0:f}".format(accuracy_score))
84 
if __name__ == "__main__":
    Main()
    # Terminate with status 0 only when executed as a script; keeping this
    # inside the guard means importing the module never exits the interpreter.
    raise SystemExit(0)

 

posted @ 2018-05-27 21:04  Run_For_Love  阅读(270)  评论(0)  编辑  收藏  举报