ROC检验分类树性能
直接上代码:
001 #############################################################
002 ############# 读取excel文件 ######################
003 #############################################################
004 #root<-"C:/Users/liming/Desktop/写书/chap7/cbc/"
005 root<-"E:/work/写书/chap7/cbc(可以使用-但是未写书书中)/"
006 file<-paste(root,"data.xls",sep="")
007 library (RODBC)
008 excel_file <- odbcConnectExcel(file)
009 data<- sqlFetch ( excel_file,"data")[,1:17]#由于excel文件问题
010 close ( excel_file )
011 ##############################################################
012 ####数据处理
013 ##############################################################
014 data$chineseCook<-as.factor(ifelse(data$chineseCook==0,0,1))#如果购买Cook书则为1,没有购买则为0
015 data$chineseHist<-as.factor(ifelse(data$chineseHist==0,0,1))#如果购买Atlas书则为1,没有购买则为0
016 data$chineseArt<-as.factor(ifelse(data$chineseArt==0,0,1))#如果购买Art书则为1,没有购买则为0
017 data$Gender<-as.factor(data$Gender)#性别
018 data$HanDynastyArt<-as.factor(data$HanDynastyArt)#目标变量
019 ##############################################################
020 ####对不平衡目标样本的样本处理
021 ##############################################################
022 balance<-function(data,yval){
023 y.vector<-with(data,get(yval))#在data数据框中读取列名称为yval参数的向量,例如如果yval设置为HanDynastyArt则y.vector就是data$HanDynastyArt
024 index.1<-which(y.vector==1)
025 index.0<-which(y.vector==0)
026 index.1<-sample(index.1,length(index.0),replace=T)
027 result<-data[sample(c(index.0,index.1)),]
028 result
029 }
030 data<-balance(data,"HanDynastyArt")
031 ##############################################################
032 ####分割训练集数据和测试数据
033 ##############################################################
034 apart.data<-function(data,train.data.persent=0.7){
035 train.index<-sample(c(1:nrow(data)),round(nrow(data)*train.data.persent))
036 data.train<-data[train.index,]
037 data.test<-data[-c(train.index),]
038 result<-list(train=data.train,test=data.test)
039 result
040 }
041 p.data<-apart.data(data)
042 data.train<-p.data$train
043 data.test<-p.data$test
044 ###############################################################
045 mod.formula<-as.formula("HanDynastyArt~Gender+M+R+F+FirstPurch+ChildBks+YouthBks+CookBks+DoItYBks+RefBks+ArtBks+GeogBks+chineseCook+chineseHist+chineseArt")
046 #### party ####
047 library("party")
048 ctree.sol<-ctree(mod.formula,data=data.train,control= ctree_control(mincriterion =0.95))
049 #plot(ctree.sol)
050 #### rpart ####
051 library("rpart")
052 rpart.sol<-rpart(mod.formula,data=data.train,control=list(cp=0.001))
053 #plot(rpart.sol,uniform=TRUE,compress=TRUE,lty=3,branch=0.7)
054 #text(rpart.sol,all=TRUE,digits=7,use.n=TRUE,cex=0.9,xpd=TRUE)
055 #### logit ####
056 glm.sol<-glm(mod.formula,data=data.train,family=binomial("logit"))
057 #### nnet ####
058 library(nnet)
059 nnet.sol<-nnet(mod.formula,data=data.train,size=30,maxit=1000)#size越高 树越大
060 #### svm ####
061 library(e1071)
062 svm.sol<-svm(mod.formula,data=data.train,probability = TRUE)#probability为T表示计算预测数值为取1和0的概率,
063 ##############################################################
064 ####性能检验
065 ##############################################################
066 library(ROCR)
067 sol.performance<-function(sol,test,add.logic=F,color=NA){
068 ########使用prediction函数初始化数据########
069 test.real<-as.numeric(ifelse(test$HanDynastyArt==0,0,1))#把目标变量的factor0变为numeric0。factor1变为numeric1。
070 sol.class<-class(sol)[1]
071 if(sol.class=="BinaryTree"){#ctree函数
072 test.pred<-as.numeric(predict(sol,test))
073 test.pred<-ifelse(test.pred==1,0,1)#把目标变量的1变为0;2变为1
074 }else{if(sol.class=="rpart"){#rpart函数
075 test.pred<-predict(sol,test)[,2]
076 }else{if(sol.class=="glm"){#glm函数
077 glm.pred<-predict(sol,test)
078 test.pred<-1/(1+exp(-glm.pred))
079 }else{if(sol.class=="nnet.formula"){#nnet函数
080 test.pred<-predict(sol,test)
081 }else{if(sol.class=="svm.formula"){#svm函数
082 test.pred<-attr(predict(sol,test,probability=T),"probabilities")[,2]
083 }else{
084 print("ERROR:sol输入有误!")
085 return()
086 }}}}}
087 predictions<-prediction(test.pred,test.real)
088 ########计算混淆矩阵并使用performance函数计算灵敏度和auc########
089 if(sol.class=="BinaryTree"){#ctree函数
090 print("混淆矩阵:")
091 print(table(test.pred,test.real,dnn=c("预测数值","真实数值")))
092 sens<-performance(predictions,'sens')@y.values[[1]][2]
093 print(paste("灵敏度(Sensitivity):",sens,sep=""))
094 spec<-performance(predictions,'spec')@y.values[[1]][2]
095 print(paste("特指度(Specicity):",spec,sep=""))
096 }else{if(sol.class=="rpart"){#rpart函数
097 print("混淆矩阵:")
098 tmp<-ifelse(as.numeric(predict(sol,test,type="class"))==1,0,1)#把目标变量的1变为0;2变为1
099 print(table(tmp,test.real,dnn=c("预测数值","真实数值")))
100 auc<-performance(predictions,'auc')@y.values
101 print(paste("ROC曲线下的面积(auc):",auc,sep=""))
102 }else{if(sol.class=="glm"){#glm函数
103 print("混淆矩阵:")
104 tmp<-ifelse(test.pred>0.5,1,0)
105 print(table(tmp,test.real,dnn=c("预测数值","真实数值")))
106 auc<-performance(predictions,'auc')@y.values
107 print(paste("ROC曲线下的面积(auc):",auc,sep=""))
108 }else{if(sol.class=="nnet.formula"){#nnet函数
109 print("混淆矩阵:")
110 tmp<-as.numeric(predict(sol,test,type="class"))
111 print(table(tmp,test.real,dnn=c("预测数值","真实数值")))
112 auc<-performance(predictions,'auc')@y.values
113 print(paste("ROC曲线下的面积(auc):",auc,sep=""))
114 }else{if(sol.class=="svm.formula"){#svm函数
115 print("混淆矩阵:")
116 tmp<-ifelse(as.numeric(predict(sol,test,type="class"))==1,0,1)#把目标变量的1变为0;2变为1
117 print(table(tmp,test.real,dnn=c("预测数值","真实数值")))
118 auc<-performance(predictions,'auc')@y.values
119 print(paste("ROC曲线下的面积(auc):",auc,sep=""))
120 }else{
121 print("ERROR:sol输入有误!")
122 return()
123 }}}}}
124 ########绘制ROC曲线####
125 #如果predict返回的是0/1(tree模型)则roc是一个点的折线,如果是0-1的概率连续值(logic nnet等)则roc是多个点组成的曲线
126 plot(performance(predictions,'tpr','fpr'),colorize=T,main="ROC图",ylab="真正率(TPR)=灵敏度(Sensitivity)",xlab="假正率(FPR)=1-特指度(1-Specicity)",add=add.logic,colorize.palette=color)
127
128 }
129 col<-rainbow(5,start=0,end=4/6)
130 sol.performance(ctree.sol,data.test,F,col[1])
131 sol.performance(rpart.sol,data.test,T,col[2])
132 sol.performance(glm.sol,data.test,T,col[3])
133 sol.performance(nnet.sol,data.test,T,col[4])
134 sol.performance(svm.sol,data.test,T,col[5])
135 id=c("ctree模型","rpart模型","glm模型","nnet模型","svm模型")
136 legend("bottomright",legend=id,horiz=T,pch=15,col=col,cex=0.8,bty="n")
数据源下载链接为http://vdisk.weibo.com/s/sFZzV0wnhRPN;
该数据集是某图书出版社研究用户是否会购买新书而做的调查问卷结果,其中的HanDynastyArt是目标变量。
输出的结果为,表示不同的模型下roc曲线: