机器学习(二)——决策树方法构建分类模型(matlab)

决策树(Decision Tree),又称判断树,它是一种以树形数据结构来展示决策规则和分类结果的模型,作为一种归纳学习算法,其重点是将看似无序、杂乱的已知实例,通过某种技术手段将它们转化成可以预测未知实例的树状模型,每一条从根结点(对最终分类结果贡献最大的属性)到叶子结点(最终分类结果)的路径都代表一条决策的规则。
决策树主要包括以下几个步骤:

1、选择最佳特征进行分割:
决策树的每个节点代表一个特征,每个分支代表该特征的一个取值或取值范围。
选择一个特征来分割数据,使得分割后的子数据集在目标变量(分类标签)上的纯度最高。常用的度量标准包括信息增益(Information Gain)、信息增益比(Information Gain Ratio)和基尼系数(Gini Index)。

2、递归地分割数据:
从根节点开始,根据最佳特征对数据进行分割,生成子节点。
对每个子节点,重复选择最佳特征进行进一步分割,直到达到停止条件。停止条件可以是所有特征都已使用、节点中的样本属于同一类别、或节点中的样本数量小于某个预设值。

3、生成叶子节点:
当无法继续分割或达到停止条件时,当前节点成为叶子节点,并分配一个分类标签。
叶子节点的分类标签通常是该节点中样本的多数类。

4、剪枝(Pruning):
为防止过拟合,可能需要对生成的决策树进行剪枝。剪枝可以分为预剪枝(Pre-pruning)和后剪枝(Post-pruning)。
预剪枝是在构建过程中提前停止分割,例如设置最大树深度或最小样本数。
后剪枝是在生成完整树后,移除一些子树或节点,以简化模型,保留对测试集的良好泛化能力。

本文同样使用了留一交叉验证的方法验证模型的分类性能,具体应用场景可以参考:机器学习(一)——递归特征消除法实现SVM(matlab)
在matlab中直接调用相关的模型就可以了,并不需要自己构建模型。模型可以帮忙选出重要的特征索引,然后选用重要特征构建最后的分类模型。

同时本文还在后面画出了分类结果的混淆矩阵,可以更直观地看到分类模型的分类效果。

通过matlab实现的代码为:

%res是我的特征矩阵,其中第一类是样本标签,后面的列为样本的特征。

labels = res(:, 1);  % 第一列是标签
features = res(:, 2:end);  % 后面的列是特征
% features = zscore(features);   %使用决策树进行分类时不需要进行归一化处理



%%   使用树模型的特征重要性进行特征选择,并将选择后的特征以及对应标签存在important_features_res中

% 构建决策树模型
tree = fitctree(features, labels);

% 获取特征重要性
importance = predictorImportance(tree);

% 可视化特征重要性
bar(importance);
xlabel('特征索引');
ylabel('特征重要性');
title('特征重要性');


% 寻找大于 0 的特征重要性的索引
important_features_idx = find(importance > 0);

% 提取对应的特征索引
important_features = important_features_idx(:);

% 显示重要特征索引
disp('重要特征索引:');
disp(important_features);

%将重要性大于0的特征索引提取出来
important_features_idx = find(importance > 0);

% 提取重要特征和标签
important_features_res = res(:, [1, important_features_idx + 1]);

labels_importance = important_features_res(:, 1);  % 第一列是标签
features_importance = important_features_res(:, 2:end); %后面的列是特征



%%  运用留一交叉验证的方法验证模型的分类效果

num_samples = size(important_features_res, 1);
accuracy = zeros(num_samples, 1);
predicted_labels=[];
predictedScores=zeros(56,2);
  
for i = 1:num_samples
    % 将第 i 个样本作为验证集,其余样本作为训练集
    train_features = features_importance;
    train_labels = labels_importance ;
    train_features(i, :) = [];
    train_labels(i) = [];
    
    test_feature = features_importance(i, :);
    test_label = labels_importance (i);
    
    % 构建决策树模型
    tree = fitctree(train_features, train_labels);
    
    % 对第 i 个样本进行预测
    [predicted_label,predictedScore] = predict(tree, test_feature);
    predictedScores(i,:)=predictedScore;

    predicted_labels(i,1)=predicted_label;
    % 计算准确率
    accuracy(i) = predicted_label == test_label;
end

% 计算平均准确率
mean_accuracy = mean(accuracy);

disp('分类的准确性为');
disp(mean_accuracy);


% predicted_labels 是预测的标签,true_labels 是真实的标签
% 计算混淆矩阵
C = confusionmat(labels, predicted_labels);


% 显示混淆矩阵
figure;
disp('混淆矩阵:');
disp(C);

% 自定义渐变蓝色颜色图
nColors = 256; % 颜色的数量
blueCmap = [linspace(0.9, 0, nColors)', linspace(0.9, 0, nColors)', linspace(1, 0.5, nColors)'];

% 绘制混淆矩阵的热力图
h = heatmap(C, 'Colormap', blueCmap, 'ColorbarVisible', 'on', 'XLabel', '预测标签', 'YLabel', '真实标签', 'Title', '基于Decision Tree方法的混淆矩阵');

% 修改混淆矩阵的标签
h.XDisplayLabels = {'LPE', 'NC'}; % 修改横坐标标签
h.YDisplayLabels = {'LPE', 'NC'}; % 修改纵坐标标签

% 更新热力图
drawnow;


posted on 2024-06-28 21:39  一只嘤嘤怪  阅读(100)  评论(0编辑  收藏  举报

导航