JavaCnn项目注解
JavaCnn项目注解
该JavaCnn项目旨在用Java语言构造一个完整的卷积神经网络,实现训练一个手写字符识别模型,并预测。该项目可以帮助我们深入到Cnn的底层原理实现,通过阅读分析该项目代码,既可以提高对Java语言的掌握,也加深了对卷积神经网络的认识。
虽然项目的功能是“识别”,但其本质上,是一个分类的过程。
项目的入口是RunCnn类,Main()函数里开了个定时器,并根据CPU核数分了线程数。// Todo
项目分训练和预测两个模块:
训练步骤:
- 创建新模型
- 载入训练集
- 载入测试集
- 调用cnn对象的train()方法
预测步骤:
- 调用cnn对象的loadModels()方法载入模型,该方法返回Cnn对象
- 载入测试集
- 测试集初始化 // Todo
- 调用cnn对象的predict()方法
predict()方法详解
- 保存模型每一层的是否开启Dropout状态,存到save[]数组中;
- 关掉每一层的dropout,确保预测的时候所有权重都参与计算;
- 初始化batch的记录为0。 // Todo
- 对于每一张图片,都进行一次forward()正向传播计算。
- 对于每一张图片,正向传播的输出是 x个数(x是样本类别数),将这10个数字存入分类预测结果的数组中。
- 从分类预测结果的数组(样本一共有x类,数组的长度就是x)中取出数组最大值对应的下标,将其和图片对应的label对比,若值相等,则正确个数 +1 。
- 计算正常率: 正确个数/总测试图片数*100%。
- 将每一层的dropout状态变回原来的状态,使训练过程得以继续。
Record类注解
一张图片对应了一个record实例对象,Record类由两个属性组成,数组attrs[]保存一张图片的所有像素值,像素值进行了归一化处理,范围为0~1。
public Record(double[] data)
{
lable = data[data.length-1];
attrs = Arrays.copyOfRange(data, 0, data.length-1);
}