常用函数框架
1、混淆矩阵
import itertools


def plot_condusion_matrix(cm, classes, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated image on the current figure.

    Relies on ``plt`` and ``np`` being available at module level
    (presumably ``matplotlib.pyplot`` and ``numpy`` -- confirm against the
    file's imports).

    NOTE(review): the function name is misspelled ("condusion"); it is kept
    unchanged so existing callers keep working.

    Args:
        cm: Matrix of values to display; must provide ``.max()``, ``.shape``
            and ``cm[i, j]`` indexing.
        classes: Sequence of class labels used as tick labels on both axes.
        title: Title shown above the plot.
        cmap: Colormap forwarded to ``plt.imshow``.
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    # One tick per class, labelled identically on both axes.
    positions = np.arange(len(classes))
    plt.xticks(positions, classes, rotation=0)
    plt.yticks(positions, classes)

    # Cells whose value exceeds half the maximum are labelled in white,
    # all other cells in black.
    half_max = cm.max() / 2.
    n_rows, n_cols = cm.shape[0], cm.shape[1]
    for row, col in itertools.product(range(n_rows), range(n_cols)):
        value = cm[row, col]
        if value > half_max:
            text_color = "white"
        else:
            text_color = "black"
        # x is the column (predicted), y is the row (true).
        plt.text(col, row, value,
                 horizontalalignment="center",
                 color=text_color)

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
2、决策树可视化
def dec_tree(model, feature_names, tagret_names):
    """Render a decision tree with Graphviz, save it, and return the image.

    Relies on ``tree`` being available at module level (presumably
    ``sklearn.tree`` -- confirm against the file's imports), and on the
    ``pydotplus`` and ``IPython`` packages plus a Graphviz installation.

    Args:
        model: The tree model passed to ``tree.export_graphviz``.
        feature_names: Names of the input features, shown in the nodes.
        tagret_names: Class names, shown in the nodes. The parameter name is
            misspelled ("tagret"); it is kept so keyword callers keep
            working.

    Returns:
        An ``IPython.display.Image`` built from the PNG rendering, so a
        notebook cell ending in ``dec_tree(...)`` displays the tree.

    Side effects:
        Appends the Graphviz ``bin`` directory to ``PATH`` (once) and writes
        the rendering to a file named ``graph_jpg`` in the working directory.
    """
    import os

    import pydotplus
    from IPython.display import Image

    # Make the Graphviz executables discoverable. Append only when missing so
    # repeated calls do not keep growing PATH.
    # NOTE(review): hard-coded Windows install location -- adjust per machine.
    graphviz_bin = 'C:/Program Files (x86)/Graphviz2.38/bin/'
    if graphviz_bin not in os.environ["PATH"].split(os.pathsep):
        os.environ["PATH"] += os.pathsep + graphviz_bin

    # out_file=None makes export_graphviz return the DOT source as a string.
    dot_data = tree.export_graphviz(
        model,
        out_file=None,
        class_names=tagret_names,  # was the undefined name `target_names`
        feature_names=feature_names,
        filled=True,
        impurity=False,
        rounded=True,
    )

    graph = pydotplus.graph_from_dot_data(dot_data)
    # Example: recolor a single node with
    # graph.get_nodes()[7].set_fillcolor('#FFF2DD')

    image = Image(graph.create_png())
    # NOTE(review): the output name has no extension ('graph_jpg'); possibly
    # 'graph.jpg' was intended -- left unchanged to preserve behavior.
    graph.write_jpg('graph_jpg')
    return image
3、训练模型时，需要遍历数据集并不断读取小批量数据样本。这里定义一个函数，每次返回 batch_size 个随机样本的特征和标签。
def data_iter(batch_size, feature, labels):
    """Yield randomly ordered mini-batches of features and labels.

    Relies on ``random`` and ``nd`` being available at module level (``nd``
    is presumably ``mxnet.ndarray`` -- confirm against the file's imports).

    Args:
        batch_size: Maximum number of samples per batch.
        feature: Container of samples; must support ``len()`` and ``.take()``.
        labels: Container of labels aligned with ``feature``; must support
            ``.take()``.

    Yields:
        ``(feature.take(idx), labels.take(idx))`` for each batch of indices.
        The last batch is shorter when the sample count is not a multiple of
        ``batch_size``.
    """
    total = len(feature)
    order = list(range(total))
    random.shuffle(order)  # read the samples in random order

    for start in range(0, total, batch_size):
        # Slicing clamps at the end of the list, so the final (possibly
        # shorter) batch needs no explicit bound check.
        batch_idx = nd.array(order[start:start + batch_size])
        # take() returns the elements at the given indices.
        yield feature.take(batch_idx), labels.take(batch_idx)