JAVA实现BP神经网络算法
工作中需要预测一个过程的时间,就想到了使用BP神经网络来进行预测。
简介
BP神经网络(Back Propagation Neural Network)是一种基于BP算法的人工神经网络,其使用BP算法进行权值与阈值的调整。在20世纪80年代,几位不同的学者分别开发出了用于训练多层感知机的反向传播算法,David Rumelhart和James McClelland提出的反向传播算法是最具影响力的。其包含BP的两大主要过程,即工作信号的正向传播与误差信号的反向传播,分别负责了神经网络中输出的计算与权值和阈值更新。工作信号的正向传播是通过计算得到BP神经网络的实际输出,误差信号的反向传播是由后往前逐层修正权值与阈值,为了使实际输出更接近期望输出。
(1)工作信号正向传播。输入信号从输入层进入,通过突触进入隐含层神经元,经传递函数运算后,传递到输出层,并在输出层计算出输出信号后传出。当工作信号正向传播时,权值与阈值固定不变,神经网络中每层的状态只与前一层的净输出、权值和阈值有关。若正向传播在输出层获得了期望的输出,则学习结束,并保留当前的权值与阈值;若正向传播在输出层得不到期望的输出,则在误差信号的反向传播中修正权值与阈值。
(2)误差信号反向传播。若工作信号正向传播后得不到期望的输出,则将BP神经网络的实际输出与期望输出之间的差值作为误差信号,由神经网络的输出层逐层向输入层传播。在此过程中,每向前传播一层,就对该层的权值与阈值进行修改,由此一直向前传播直至输入层,该过程是为了使神经网络的结果与期望的结果更相近。
当进行一次正向传播和反向传播后,若误差仍不能达到要求,则该过程继续下去,直至误差满足精度,或者满足迭代次数等其他设置的结束条件。
推导请见 https://zh.wikipedia.org/wiki/%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD%E7%AE%97%E6%B3%95
BPNN结构
该BPNN为单输入层单隐含层单输出层结构
项目结构
介绍一些用到的类
- ActivationFunction:激活函数的接口
- BPModel:BP模型实体类
- BPNeuralNetworkFactory:BP神经网络工厂,包括训练BP神经网络,计算,序列化等功能
- BPParameter:BP神经网络参数实体类
- Matrix:矩阵实体类
- Sigmoid:Sigmoid激活函数(即传递函数),实现了ActivationFunction接口
- MatrixUtil:矩阵工具类
实现代码
Matrix实体类
模拟了矩阵的基本运算方法。
package com.top.matrix; import com.top.constants.OrderEnum; import java.io.Serializable; public class Matrix implements Serializable { private double[][] matrix; //矩阵列数 private int matrixColCount; //矩阵行数 private int matrixRowCount; /** * 构造一个空矩阵 */ public Matrix() { this.matrix = null; this.matrixColCount = 0; this.matrixRowCount = 0; } /** * 构造一个matrix矩阵 * @param matrix */ public Matrix(double[][] matrix) { this.matrix = matrix; this.matrixRowCount = matrix.length; this.matrixColCount = matrix[0].length; } /** * 构造一个rowCount行colCount列值为0的矩阵 * @param rowCount * @param colCount */ public Matrix(int rowCount,int colCount) { double[][] matrix = new double[rowCount][colCount]; for (int i = 0; i < rowCount; i++) { for (int j = 0; j < colCount; j++) { matrix[i][j] = 0; } } this.matrix = matrix; this.matrixRowCount = rowCount; this.matrixColCount = colCount; } /** * 构造一个rowCount行colCount列值为val的矩阵 * @param val * @param rowCount * @param colCount */ public Matrix(double val,int rowCount,int colCount) { double[][] matrix = new double[rowCount][colCount]; for (int i = 0; i < rowCount; i++) { for (int j = 0; j < colCount; j++) { matrix[i][j] = val; } } this.matrix = matrix; this.matrixRowCount = rowCount; this.matrixColCount = colCount; } public double[][] getMatrix() { return matrix; } public void setMatrix(double[][] matrix) { this.matrix = matrix; this.matrixRowCount = matrix.length; this.matrixColCount = matrix[0].length; } public int getMatrixColCount() { return matrixColCount; } public int getMatrixRowCount() { return matrixRowCount; } /** * 获取矩阵指定位置的值 * * @param x * @param y * @return */ public double getValOfIdx(int x, int y) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (x > matrixRowCount - 1) { throw new IllegalArgumentException("索引x越界"); } if (y > matrixColCount - 1) { throw new IllegalArgumentException("索引y越界"); } return matrix[x][y]; } /** * 获取矩阵指定行 * * @param x * @return */ public Matrix getRowOfIdx(int 
x) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (x > matrixRowCount - 1) { throw new IllegalArgumentException("索引x越界"); } double[][] result = new double[1][matrixColCount]; result[0] = matrix[x]; return new Matrix(result); } /** * 获取矩阵指定列 * * @param y * @return */ public Matrix getColOfIdx(int y) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (y > matrixColCount - 1) { throw new IllegalArgumentException("索引y越界"); } double[][] result = new double[matrixRowCount][1]; for (int i = 0; i < matrixRowCount; i++) { result[i][0] = matrix[i][y]; } return new Matrix(result); } /** * 设置矩阵中x,y位置元素的值 * @param x * @param y * @param val */ public void setValue(int x, int y, double val) { if (x > this.matrixRowCount - 1) { throw new IllegalArgumentException("行索引越界"); } if (y > this.matrixColCount - 1) { throw new IllegalArgumentException("列索引越界"); } this.matrix[x][y] = val; } /** * 矩阵乘矩阵 * * @param a * @return * @throws IllegalArgumentException */ public Matrix multiple(Matrix a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (a.getMatrix() == null) { throw new IllegalArgumentException("参数矩阵为空"); } if (matrixColCount != a.getMatrixRowCount()) { throw new IllegalArgumentException("矩阵纬度不同,不可计算"); } double[][] result = new double[matrixRowCount][a.getMatrixColCount()]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < a.getMatrixColCount(); j++) { for (int k = 0; k < matrixColCount; k++) { result[i][j] = result[i][j] + matrix[i][k] * a.getMatrix()[k][j]; } } } return new Matrix(result); } /** * 矩阵乘一个数字 * * @param a * @return */ public Matrix multiple(double a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < 
matrixColCount; j++) { result[i][j] = matrix[i][j] * a; } } return new Matrix(result); } /** * 矩阵点乘 * * @param a * @return */ public Matrix pointMultiple(Matrix a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (a.getMatrix() == null) { throw new IllegalArgumentException("参数矩阵为空"); } if (matrixRowCount != a.getMatrixRowCount() && matrixColCount != a.getMatrixColCount()) { throw new IllegalArgumentException("矩阵纬度不同,不可计算"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = matrix[i][j] * a.getMatrix()[i][j]; } } return new Matrix(result); } /** * 矩阵除一个数字 * @param a * @return * @throws IllegalArgumentException */ public Matrix divide(double a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = matrix[i][j] / a; } } return new Matrix(result); } /** * 矩阵加法 * * @param a * @return */ public Matrix plus(Matrix a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (a.getMatrix() == null) { throw new IllegalArgumentException("参数矩阵为空"); } if (matrixRowCount != a.getMatrixRowCount() && matrixColCount != a.getMatrixColCount()) { throw new IllegalArgumentException("矩阵纬度不同,不可计算"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = matrix[i][j] + a.getMatrix()[i][j]; } } return new Matrix(result); } /** * 矩阵加一个数字 * @param a * @return * @throws IllegalArgumentException */ public Matrix plus(double a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new 
double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = matrix[i][j] + a; } } return new Matrix(result); } /** * 矩阵减法 * * @param a * @return */ public Matrix subtract(Matrix a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (a.getMatrix() == null) { throw new IllegalArgumentException("参数矩阵为空"); } if (matrixRowCount != a.getMatrixRowCount() && matrixColCount != a.getMatrixColCount()) { throw new IllegalArgumentException("矩阵纬度不同,不可计算"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = matrix[i][j] - a.getMatrix()[i][j]; } } return new Matrix(result); } /** * 矩阵减一个数字 * @param a * @return * @throws IllegalArgumentException */ public Matrix subtract(double a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = matrix[i][j] - a; } } return new Matrix(result); } /** * 矩阵行求和 * * @return */ public Matrix sumRow() throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[matrixRowCount][1]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][0] += matrix[i][j]; } } return new Matrix(result); } /** * 矩阵列求和 * * @return */ public Matrix sumCol() throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[1][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[0][j] += matrix[i][j]; } } return new Matrix(result); } /** * 矩阵所有元素求和 * * @return */ public double sumAll() 
throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double result = 0; for (double[] doubles : matrix) { for (int j = 0; j < matrixColCount; j++) { result += doubles[j]; } } return result; } /** * 矩阵所有元素求平方 * * @return */ public Matrix square() throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = matrix[i][j] * matrix[i][j]; } } return new Matrix(result); } /** * 矩阵所有元素求N次方 * * @return */ public Matrix pow(double n) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = Math.pow(matrix[i][j],n); } } return new Matrix(result); } /** * 矩阵转置 * * @return */ public Matrix transpose() throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[matrixColCount][matrixRowCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[j][i] = matrix[i][j]; } } return new Matrix(result); } /** * 截取矩阵 * @param startRowIndex 开始行索引 * @param rowCount 截取行数 * @param startColIndex 开始列索引 * @param colCount 截取列数 * @return * @throws IllegalArgumentException */ public Matrix subMatrix(int startRowIndex,int rowCount,int startColIndex,int colCount) throws IllegalArgumentException { if (startRowIndex + rowCount > matrixRowCount) { throw new IllegalArgumentException("行索引越界"); } if (startColIndex + colCount> matrixColCount) { throw new IllegalArgumentException("列索引越界"); } double[][] result = new double[rowCount][colCount]; for (int i = startRowIndex; i < startRowIndex + rowCount; i++) { if (startColIndex + colCount - 
startColIndex >= 0) System.arraycopy(matrix[i], startColIndex, result[i - startRowIndex], 0, colCount); } return new Matrix(result); } /** * 矩阵合并 * @param direction 合并方向,1为横向,2为竖向 * @param a * @return * @throws IllegalArgumentException */ public Matrix splice(int direction, Matrix a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (a.getMatrix() == null) { throw new IllegalArgumentException("参数矩阵为空"); } if(direction == 1){ //横向拼接 if (matrixRowCount != a.getMatrixRowCount()) { throw new IllegalArgumentException("矩阵行数不一致,无法拼接"); } double[][] result = new double[matrixRowCount][matrixColCount + a.getMatrixColCount()]; for (int i = 0; i < matrixRowCount; i++) { System.arraycopy(matrix[i],0,result[i],0,matrixColCount); System.arraycopy(a.getMatrix()[i],0,result[i],matrixColCount,a.getMatrixColCount()); } return new Matrix(result); }else if(direction == 2){ //纵向拼接 if (matrixColCount != a.getMatrixColCount()) { throw new IllegalArgumentException("矩阵列数不一致,无法拼接"); } double[][] result = new double[matrixRowCount + a.getMatrixRowCount()][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { result[i] = matrix[i]; } for (int i = 0; i < a.getMatrixRowCount(); i++) { result[matrixRowCount + i] = a.getMatrix()[i]; } return new Matrix(result); }else{ throw new IllegalArgumentException("方向参数有误"); } } /** * 扩展矩阵 * @param direction 扩展方向,1为横向,2为竖向 * @param a * @return * @throws IllegalArgumentException */ public Matrix extend(int direction , int a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if(direction == 1){ //横向复制 double[][] result = new double[matrixRowCount][matrixColCount*a]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < a; j++) { System.arraycopy(matrix[i],0,result[i],j*matrixColCount,matrixColCount); } } return new Matrix(result); }else if(direction == 2){ //纵向复制 double[][] result = new double[matrixRowCount*a][matrixColCount]; for (int i = 
0; i < matrixRowCount*a; i++) { result[i] = matrix[i%matrixRowCount]; } return new Matrix(result); }else{ throw new IllegalArgumentException("方向参数有误"); } } /** * 获取每列的平均值 * @return * @throws IllegalArgumentException */ public Matrix getColAvg() throws IllegalArgumentException { Matrix tmp = this.sumCol(); return tmp.divide(matrixRowCount); } /** * 矩阵行排序 * @param index 根据第几列的数进行行排序 * @param order 排序顺序,升序或降序 * @return * @throws IllegalArgumentException */ public void sort(int index,OrderEnum order) throws IllegalArgumentException{ switch (order){ case ASC: for (int i = 0; i < this.matrixRowCount; i++) { for (int j = 0; j < this.matrixRowCount - 1 - i; j++) { if (this.matrix[j][index] > this.matrix[j + 1][index]) { double[] tmp = this.matrix[j]; this.matrix[j] = this.matrix[j + 1]; this.matrix[j + 1] = tmp; } } } break; case DESC: for (int i = 0; i < this.matrixRowCount; i++) { for (int j = 0; j < this.matrixRowCount - 1 - i; j++) { if (this.matrix[j][index] < this.matrix[j + 1][index]) { double[] tmp = this.matrix[j]; this.matrix[j] = this.matrix[j + 1]; this.matrix[j + 1] = tmp; } } } break; default: } } /** * 判断是否是方阵 * 行列数相等,并且不等于0 * @return */ public boolean isSquareMatrix(){ return matrixColCount == matrixRowCount && matrixColCount != 0; } @Override public String toString() { StringBuilder stringBuilder = new StringBuilder(); stringBuilder.append("\r\n"); for (int i = 0; i < matrixRowCount; i++) { stringBuilder.append("# "); for (int j = 0; j < matrixColCount; j++) { stringBuilder.append(matrix[i][j]).append("\t "); } stringBuilder.append("#\r\n"); } stringBuilder.append("\r\n"); return stringBuilder.toString(); } }
MatrixUtil工具类
package com.top.utils; import com.top.matrix.Matrix; import java.util.*; public class MatrixUtil { /** * 创建一个单位矩阵 * @param matrixRowCount 单位矩阵的纬度 * @return */ public static Matrix eye(int matrixRowCount){ double[][] result = new double[matrixRowCount][matrixRowCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixRowCount; j++) { if(i == j){ result[i][j] = 1; }else{ result[i][j] = 0; } } } return new Matrix(result); } /** * 求矩阵的逆 * 原理:AE=EA^-1 * @param a * @return * @throws Exception */ public static Matrix inv(Matrix a) throws Exception { if (!invable(a)) { throw new Exception("矩阵不可逆"); } // [a|E] Matrix b = a.splice(1, eye(a.getMatrixRowCount())); double[][] data = b.getMatrix(); int rowCount = b.getMatrixRowCount(); int colCount = b.getMatrixColCount(); //此处应用a的列数,为简化,直接用b的行数 for (int j = 0; j < rowCount; j++) { //若遇到0则交换两行 int notZeroRow = -2; if(data[j][j] == 0){ notZeroRow = -1; for (int l = j; l < rowCount; l++) { if (data[l][j] != 0) { notZeroRow = l; break; } } } if (notZeroRow == -1) { throw new Exception("矩阵不可逆"); }else if(notZeroRow != -2){ //交换j与notZeroRow两行 double[] tmp = data[j]; data[j] = data[notZeroRow]; data[notZeroRow] = tmp; } //将第data[j][j]化为1 if (data[j][j] != 1) { double multiple = data[j][j]; for (int colIdx = j; colIdx < colCount; colIdx++) { data[j][colIdx] /= multiple; } } //行与行相减 for (int i = 0; i < rowCount; i++) { if (i != j) { double multiple = data[i][j] / data[j][j]; //遍历行中的列 for (int k = j; k < colCount; k++) { data[i][k] = data[i][k] - multiple * data[j][k]; } } } } Matrix result = new Matrix(data); return result.subMatrix(0, rowCount, rowCount, rowCount); } /** * 求矩阵的伴随矩阵 * 原理:A*=|A|A^-1 * @param a * @return * @throws Exception */ public static Matrix adj(Matrix a) throws Exception { return inv(a).multiple(det(a)); } /** * 矩阵转成上三角矩阵 * @param a * @return * @throws Exception */ public static Matrix getTopTriangle(Matrix a) throws Exception { if (!a.isSquareMatrix()) { throw new Exception("不是方阵无法进行计算"); } int 
matrixHeight = a.getMatrixRowCount(); double[][] result = a.getMatrix(); //遍历列 for (int j = 0; j < matrixHeight; j++) { //遍历行 for (int i = j+1; i < matrixHeight; i++) { //若遇到0则交换两行 int notZeroRow = -2; if(result[j][j] == 0){ notZeroRow = -1; for (int l = i; l < matrixHeight; l++) { if (result[l][j] != 0) { notZeroRow = l; break; } } } if (notZeroRow == -1) { throw new Exception("矩阵不可逆"); }else if(notZeroRow != -2){ //交换j与notZeroRow两行 double[] tmp = result[j]; result[j] = result[notZeroRow]; result[notZeroRow] = tmp; } double multiple = result[i][j]/result[j][j]; //遍历行中的列 for (int k = j; k < matrixHeight; k++) { result[i][k] = result[i][k] - multiple * result[j][k]; } } } return new Matrix(result); } /** * 计算矩阵的行列式 * @param a * @return * @throws Exception */ public static double det(Matrix a) throws Exception { //将矩阵转成上三角矩阵 Matrix b = MatrixUtil.getTopTriangle(a); double result = 1; //计算矩阵行列式 for (int i = 0; i < b.getMatrixRowCount(); i++) { result *= b.getValOfIdx(i, i); } return result; } /** * 获取协方差矩阵 * @param a * @return * @throws Exception */ public static Matrix cov(Matrix a) throws Exception { if (a.getMatrix() == null) { throw new Exception("矩阵为空"); } Matrix avg = a.getColAvg().extend(2, a.getMatrixRowCount()); Matrix tmp = a.subtract(avg); return tmp.transpose().multiple(tmp).multiple(1/((double) a.getMatrixRowCount() -1)); } /** * 判断矩阵是否可逆 * 如果可转为上三角矩阵则可逆 * @param a * @return */ public static boolean invable(Matrix a) { try { getTopTriangle(a); return true; } catch (Exception e) { return false; } } /** * 数据归一化 * @param a 要归一化的数据 * @param normalizationMin 要归一化的区间下限 * @param normalizationMax 要归一化的区间上限 * @return */ public static Map<String, Object> normalize(Matrix a, double normalizationMin, double normalizationMax) throws Exception { HashMap<String, Object> result = new HashMap<>(); double[][] maxArr = new double[1][a.getMatrixColCount()]; double[][] minArr = new double[1][a.getMatrixColCount()]; double[][] res = new 
double[a.getMatrixRowCount()][a.getMatrixColCount()]; for (int i = 0; i < a.getMatrixColCount(); i++) { List tmp = new ArrayList(); for (int j = 0; j < a.getMatrixRowCount(); j++) { tmp.add(a.getValOfIdx(j,i)); } double max = (double) Collections.max(tmp); double min = (double) Collections.min(tmp); //数据归一化(注:若max与min均为0则不需要归一化) if (max != 0 || min != 0) { for (int j = 0; j < a.getMatrixRowCount(); j++) { res[j][i] = normalizationMin + (a.getValOfIdx(j,i) - min) / (max - min) * (normalizationMax - normalizationMin); } } maxArr[0][i] = max; minArr[0][i] = min; } result.put("max", new Matrix(maxArr)); result.put("min", new Matrix(minArr)); result.put("res", new Matrix(res)); return result; } /** * 反归一化 * @param a 要反归一化的数据 * @param normalizationMin 要反归一化的区间下限 * @param normalizationMax 要反归一化的区间上限 * @param dataMax 数据最大值 * @param dataMin 数据最小值 * @return */ public static Matrix inverseNormalize(Matrix a, double normalizationMax, double normalizationMin , Matrix dataMax,Matrix dataMin){ double[][] res = new double[a.getMatrixRowCount()][a.getMatrixColCount()]; for (int i = 0; i < a.getMatrixColCount(); i++) { //数据反归一化 if (dataMin.getValOfIdx(0,i) != 0 || dataMax.getValOfIdx(0,i) != 0) { for (int j = 0; j < a.getMatrixRowCount(); j++) { res[j][i] = dataMin.getValOfIdx(0,i) + (dataMax.getValOfIdx(0,i) - dataMin.getValOfIdx(0,i)) * (a.getValOfIdx(j,i) - normalizationMin) / (normalizationMax - normalizationMin); } } } return new Matrix(res); } }
ActivationFunction接口
/**
 * A scalar activation function together with its derivative.
 */
public interface ActivationFunction {

    /**
     * Evaluates the function.
     *
     * @param val input value
     * @return the function value at {@code val}
     */
    double computeValue(double val);

    /**
     * Evaluates the derivative.
     *
     * @param val input value
     * @return the derivative at {@code val}
     */
    double computeDerivative(double val);
}
Sigmoid
import java.io.Serializable;

/**
 * Logistic sigmoid activation, {@code 1 / (1 + e^-x)}.
 */
public class Sigmoid implements ActivationFunction, Serializable {

    /**
     * @param val input value
     * @return the sigmoid of {@code val}
     */
    @Override
    public double computeValue(double val) {
        double decay = Math.exp(-val);
        return 1 / (1 + decay);
    }

    /**
     * @param val input value
     * @return the sigmoid derivative at {@code val}, computed as {@code s * (1 - s)} where
     *         {@code s} is the sigmoid of {@code val}
     */
    @Override
    public double computeDerivative(double val) {
        double s = computeValue(val);
        return s * (1 - s);
    }
}
BPParameter
包含了BP神经网络训练所需的参数
package com.top.bpnn; import java.io.Serializable; public class BPParameter implements Serializable { //输入层神经元个数 private int inputLayerNeuronCount = 3; //隐含层神经元个数 private int hiddenLayerNeuronCount = 3; //输出层神经元个数 private int outputLayerNeuronCount = 1; //归一化区间 private double normalizationMin = 0.2; private double normalizationMax = 0.8; //学习步长 private double step = 0.05; //动量因子 private double momentumFactor = 0.2; //激活函数 private ActivationFunction activationFunction = new Sigmoid(); //精度 private double precision = 0.000001; //最大循环次数 private int maxTimes = 1000000; public double getMomentumFactor() { return momentumFactor; } public void setMomentumFactor(double momentumFactor) { this.momentumFactor = momentumFactor; } public double getStep() { return step; } public void setStep(double step) { this.step = step; } public double getNormalizationMin() { return normalizationMin; } public void setNormalizationMin(double normalizationMin) { this.normalizationMin = normalizationMin; } public double getNormalizationMax() { return normalizationMax; } public void setNormalizationMax(double normalizationMax) { this.normalizationMax = normalizationMax; } public int getInputLayerNeuronCount() { return inputLayerNeuronCount; } public void setInputLayerNeuronCount(int inputLayerNeuronCount) { this.inputLayerNeuronCount = inputLayerNeuronCount; } public int getHiddenLayerNeuronCount() { return hiddenLayerNeuronCount; } public void setHiddenLayerNeuronCount(int hiddenLayerNeuronCount) { this.hiddenLayerNeuronCount = hiddenLayerNeuronCount; } public int getOutputLayerNeuronCount() { return outputLayerNeuronCount; } public void setOutputLayerNeuronCount(int outputLayerNeuronCount) { this.outputLayerNeuronCount = outputLayerNeuronCount; } public ActivationFunction getActivationFunction() { return activationFunction; } public void setActivationFunction(ActivationFunction activationFunction) { this.activationFunction = activationFunction; } public double getPrecision() { return 
precision; } public void setPrecision(double precision) { this.precision = precision; } public int getMaxTimes() { return maxTimes; } public void setMaxTimes(int maxTimes) { this.maxTimes = maxTimes; } }
BPModel
BP神经网络模型,包括权值与阈值及训练参数等属性
package com.top.bpnn; import com.top.matrix.Matrix; import java.io.Serializable; public class BPModel implements Serializable { //BP神经网络权值与阈值 private Matrix weightIJ; private Matrix b1; private Matrix weightJP; private Matrix b2; /*用于反归一化*/ private Matrix inputMax; private Matrix inputMin; private Matrix outputMax; private Matrix outputMin; /*BP神经网络训练参数*/ private BPParameter bpParameter; /*BP神经网络训练情况*/ private double error; private int times; public Matrix getWeightIJ() { return weightIJ; } public void setWeightIJ(Matrix weightIJ) { this.weightIJ = weightIJ; } public Matrix getB1() { return b1; } public void setB1(Matrix b1) { this.b1 = b1; } public Matrix getWeightJP() { return weightJP; } public void setWeightJP(Matrix weightJP) { this.weightJP = weightJP; } public Matrix getB2() { return b2; } public void setB2(Matrix b2) { this.b2 = b2; } public Matrix getInputMax() { return inputMax; } public void setInputMax(Matrix inputMax) { this.inputMax = inputMax; } public Matrix getInputMin() { return inputMin; } public void setInputMin(Matrix inputMin) { this.inputMin = inputMin; } public Matrix getOutputMax() { return outputMax; } public void setOutputMax(Matrix outputMax) { this.outputMax = outputMax; } public Matrix getOutputMin() { return outputMin; } public void setOutputMin(Matrix outputMin) { this.outputMin = outputMin; } public BPParameter getBpParameter() { return bpParameter; } public void setBpParameter(BPParameter bpParameter) { this.bpParameter = bpParameter; } public double getError() { return error; } public void setError(double error) { this.error = error; } public int getTimes() { return times; } public void setTimes(int times) { this.times = times; } }
BPNeuralNetworkFactory
BP神经网络工厂,包含了BP神经网络训练等功能
package com.top.bpnn; import com.top.matrix.Matrix; import com.top.utils.MatrixUtil; import java.util.*; public class BPNeuralNetworkFactory { /** * 训练BP神经网络模型 * @param bpParameter * @param inputAndOutput * @return */ public BPModel trainBP(BPParameter bpParameter, Matrix inputAndOutput) throws Exception { ActivationFunction activationFunction = bpParameter.getActivationFunction(); int inputCount = bpParameter.getInputLayerNeuronCount(); int hiddenCount = bpParameter.getHiddenLayerNeuronCount(); int outputCount = bpParameter.getOutputLayerNeuronCount(); double normalizationMin = bpParameter.getNormalizationMin(); double normalizationMax = bpParameter.getNormalizationMax(); double step = bpParameter.getStep(); double momentumFactor = bpParameter.getMomentumFactor(); double precision = bpParameter.getPrecision(); int maxTimes = bpParameter.getMaxTimes(); if(inputAndOutput.getMatrixColCount() != inputCount + outputCount){ throw new Exception("神经元个数不符,请修改"); } // 初始化权值 Matrix weightIJ = initWeight(inputCount, hiddenCount); Matrix weightJP = initWeight(hiddenCount, outputCount); // 初始化阈值 Matrix b1 = initThreshold(hiddenCount); Matrix b2 = initThreshold(outputCount); // 动量项 Matrix deltaWeightIJ0 = new Matrix(inputCount, hiddenCount); Matrix deltaWeightJP0 = new Matrix(hiddenCount, outputCount); Matrix deltaB10 = new Matrix(1, hiddenCount); Matrix deltaB20 = new Matrix(1, outputCount); // 截取输入矩阵和输出矩阵 Matrix input = inputAndOutput.subMatrix(0,inputAndOutput.getMatrixRowCount(),0,inputCount); Matrix output = inputAndOutput.subMatrix(0,inputAndOutput.getMatrixRowCount(),inputCount,outputCount); // 归一化 Map<String,Object> inputAfterNormalize = MatrixUtil.normalize(input, normalizationMin, normalizationMax); input = (Matrix) inputAfterNormalize.get("res"); Map<String,Object> outputAfterNormalize = MatrixUtil.normalize(output, normalizationMin, normalizationMax); output = (Matrix) outputAfterNormalize.get("res"); int times = 1; double E = 0;//误差 while (times < maxTimes) { 
/*-----------------正向传播---------------------*/ // 隐含层输入 Matrix jIn = input.multiple(weightIJ); // 扩充阈值 Matrix b1Copy = b1.extend(2,jIn.getMatrixRowCount()); // 加上阈值 jIn = jIn.plus(b1Copy); // 隐含层输出 Matrix jOut = computeValue(jIn,activationFunction); // 输出层输入 Matrix pIn = jOut.multiple(weightJP); // 扩充阈值 Matrix b2Copy = b2.extend(2, pIn.getMatrixRowCount()); // 加上阈值 pIn = pIn.plus(b2Copy); // 输出层输出 Matrix pOut = computeValue(pIn,activationFunction); // 计算误差 Matrix e = output.subtract(pOut); E = computeE(e);//误差 // 判断是否符合精度 if (Math.abs(E) <= precision) { System.out.println("满足精度"); break; } /*-----------------反向传播---------------------*/ // J与P之间权值修正量 Matrix deltaWeightJP = e.multiple(step); deltaWeightJP = deltaWeightJP.pointMultiple(computeDerivative(pIn,activationFunction)); deltaWeightJP = deltaWeightJP.transpose().multiple(jOut); deltaWeightJP = deltaWeightJP.transpose(); // P层神经元阈值修正量 Matrix deltaThresholdP = e.multiple(step); deltaThresholdP = deltaThresholdP.transpose().multiple(computeDerivative(pIn, activationFunction)); // I与J之间的权值修正量 Matrix deltaO = e.pointMultiple(computeDerivative(pIn,activationFunction)); Matrix tmp = weightJP.multiple(deltaO.transpose()).transpose(); Matrix deltaWeightIJ = tmp.pointMultiple(computeDerivative(jIn, activationFunction)); deltaWeightIJ = input.transpose().multiple(deltaWeightIJ); deltaWeightIJ = deltaWeightIJ.multiple(step); // J层神经元阈值修正量 Matrix deltaThresholdJ = tmp.transpose().multiple(computeDerivative(jIn, activationFunction)); deltaThresholdJ = deltaThresholdJ.multiple(-step); if (times == 1) { // 更新权值与阈值 weightIJ = weightIJ.plus(deltaWeightIJ); weightJP = weightJP.plus(deltaWeightJP); b1 = b1.plus(deltaThresholdJ); b2 = b2.plus(deltaThresholdP); }else{ // 加动量项 weightIJ = weightIJ.plus(deltaWeightIJ).plus(deltaWeightIJ0.multiple(momentumFactor)); weightJP = weightJP.plus(deltaWeightJP).plus(deltaWeightJP0.multiple(momentumFactor)); b1 = b1.plus(deltaThresholdJ).plus(deltaB10.multiple(momentumFactor)); b2 = 
b2.plus(deltaThresholdP).plus(deltaB20.multiple(momentumFactor)); } deltaWeightIJ0 = deltaWeightIJ; deltaWeightJP0 = deltaWeightJP; deltaB10 = deltaThresholdJ; deltaB20 = deltaThresholdP; times++; } // BP神经网络的输出 BPModel result = new BPModel(); result.setInputMax((Matrix) inputAfterNormalize.get("max")); result.setInputMin((Matrix) inputAfterNormalize.get("min")); result.setOutputMax((Matrix) outputAfterNormalize.get("max")); result.setOutputMin((Matrix) outputAfterNormalize.get("min")); result.setWeightIJ(weightIJ); result.setWeightJP(weightJP); result.setB1(b1); result.setB2(b2); result.setError(E); result.setTimes(times); result.setBpParameter(bpParameter); System.out.println("循环次数:" + times + ",误差:" + E); return result; } /** * 计算BP神经网络的值 * @param bpModel * @param input * @return */ public Matrix computeBP(BPModel bpModel,Matrix input) throws Exception { if (input.getMatrixColCount() != bpModel.getBpParameter().getInputLayerNeuronCount()) { throw new Exception("输入矩阵纬度有误"); } ActivationFunction activationFunction = bpModel.getBpParameter().getActivationFunction(); Matrix weightIJ = bpModel.getWeightIJ(); Matrix weightJP = bpModel.getWeightJP(); Matrix b1 = bpModel.getB1(); Matrix b2 = bpModel.getB2(); double[][] normalizedInput = new double[input.getMatrixRowCount()][input.getMatrixColCount()]; for (int i = 0; i < input.getMatrixRowCount(); i++) { for (int j = 0; j < input.getMatrixColCount(); j++) { normalizedInput[i][j] = bpModel.getBpParameter().getNormalizationMin() + (input.getValOfIdx(i,j) - bpModel.getInputMin().getValOfIdx(0,j)) / (bpModel.getInputMax().getValOfIdx(0,j) - bpModel.getInputMin().getValOfIdx(0,j)) * (bpModel.getBpParameter().getNormalizationMax() - bpModel.getBpParameter().getNormalizationMin()); } } Matrix normalizedInputMatrix = new Matrix(normalizedInput); Matrix jIn = normalizedInputMatrix.multiple(weightIJ); // 扩充阈值 Matrix b1Copy = b1.extend(2,jIn.getMatrixRowCount()); // 加上阈值 jIn = jIn.plus(b1Copy); // 隐含层输出 Matrix jOut = 
computeValue(jIn,activationFunction); // 输出层输入 Matrix pIn = jOut.multiple(weightJP); // 扩充阈值 Matrix b2Copy = b2.extend(2,pIn.getMatrixRowCount()); // 加上阈值 pIn = pIn.plus(b2Copy); // 输出层输出 Matrix pOut = computeValue(pIn,activationFunction); // 反归一化 return MatrixUtil.inverseNormalize(pOut, bpModel.getBpParameter().getNormalizationMax(), bpModel.getBpParameter().getNormalizationMin(), bpModel.getOutputMax(), bpModel.getOutputMin()); } // 初始化权值 private Matrix initWeight(int x,int y){ Random random=new Random(); double[][] weight = new double[x][y]; for (int i = 0; i < x; i++) { for (int j = 0; j < y; j++) { weight[i][j] = 2*random.nextDouble()-1; } } return new Matrix(weight); } // 初始化阈值 private Matrix initThreshold(int x){ Random random = new Random(); double[][] result = new double[1][x]; for (int i = 0; i < x; i++) { result[0][i] = 2*random.nextDouble()-1; } return new Matrix(result); } /** * 计算激活函数的值 * @param a * @return */ private Matrix computeValue(Matrix a, ActivationFunction activationFunction) throws Exception { if (a.getMatrix() == null) { throw new Exception("参数值为空"); } double[][] result = new double[a.getMatrixRowCount()][a.getMatrixColCount()]; for (int i = 0; i < a.getMatrixRowCount(); i++) { for (int j = 0; j < a.getMatrixColCount(); j++) { result[i][j] = activationFunction.computeValue(a.getValOfIdx(i,j)); } } return new Matrix(result); } /** * 激活函数导数的值 * @param a * @return */ private Matrix computeDerivative(Matrix a , ActivationFunction activationFunction) throws Exception { if (a.getMatrix() == null) { throw new Exception("参数值为空"); } double[][] result = new double[a.getMatrixRowCount()][a.getMatrixColCount()]; for (int i = 0; i < a.getMatrixRowCount(); i++) { for (int j = 0; j < a.getMatrixColCount(); j++) { result[i][j] = activationFunction.computeDerivative(a.getValOfIdx(i,j)); } } return new Matrix(result); } /** * 计算误差 * @param e * @return */ private double computeE(Matrix e){ e = e.square(); return 0.5*e.sumAll(); } }
使用方式
使用思路如下:首先创建BPNeuralNetworkFactory对象和BPParameter对象,调用BPNeuralNetworkFactory的trainBP(BPParameter bpParameter, Matrix inputAndOutput)方法进行训练,该方法返回一个BPModel对象。可以使用BPNeuralNetworkFactory的序列化方法将BPModel序列化到本地,或者将其放到缓存中。使用时直接从本地反序列化(或从缓存中)获取到BPModel对象,再调用BPNeuralNetworkFactory的computeBP(BPModel bpModel,Matrix input)方法,即可获取计算值。
使用详情请看:https://github.com/ineedahouse/top-algorithm-set-doc/blob/master/doc/bpnn/BPNeuralNetwork.md
源码github地址
https://github.com/ineedahouse/top-algorithm-set
对您有帮助的话,请点个Star~谢谢