Implementing the BP Neural Network Algorithm in Java

At work I needed to predict how long a process would take, so I decided to use a BP neural network for the prediction.

Introduction

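A BP neural network (Back Propagation Neural Network) is an artificial neural network that is trained with the back-propagation (BP) algorithm, which adjusts the network's weights and thresholds. In the 1980s several researchers independently developed back-propagation for training multilayer perceptrons. Of these versions, the one proposed by David Rumelhart and James McClelland was the most influential. The algorithm has two main phases:

- Forward propagation of the working signal computes the network's actual output.
- Backward propagation of the error signal corrects the weights and thresholds layer by layer, from the output layer back toward the input layer, so that the actual output moves closer to the expected output.

(1) Forward propagation of the working signal. The input signal enters at the input layer and travels through the synapses (weighted connections) to the hidden-layer neurons. There it is transformed by the transfer function and passed on to the output layer, which computes and emits the output signal. During forward propagation the weights and thresholds stay fixed, and the state of each layer depends only on the net output of the previous layer and on the weights and thresholds. If the output layer produces the expected output, learning ends and the current weights and thresholds are kept. Otherwise they are corrected during backward propagation of the error signal.

(2) Backward propagation of the error signal. If forward propagation does not produce the expected output, the difference between the network's actual output and the expected output is taken as the error signal. This signal is propagated backward, layer by layer, from the output layer toward the input layer. At each layer it passes through, that layer's weights and thresholds are adjusted, until the input layer is reached. The purpose is to bring the network's output closer to the expected result.

If the error still does not meet the requirement after one forward pass and one backward pass, the process repeats. It continues until the error reaches the required precision or another stopping condition is met, such as a maximum number of iterations.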
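The iterative procedure above can be sketched on plain arrays. The class below (TinyBp) is hypothetical and self-contained, and it is not the post's BPNeuralNetworkFactory. It makes several assumptions:

- It uses a sigmoid activation.
- It updates the weights after every sample.
- It leaves out the momentum term.
- It measures the error as half the squared difference, averaged over the samples.

Training stops once that error drops below precision or after maxTimes epochs, mirroring the two stopping conditions described here.

```java
import java.util.Random;

// Hypothetical single-hidden-layer BP network on plain arrays (not the post's factory).
// Assumptions: sigmoid activation, per-sample updates, no momentum term.
class TinyBp {
    final int in, hid, out;
    final double[][] wIJ, wJP;   // input->hidden and hidden->output weights
    final double[] b1, b2;       // hidden and output biases (thresholds)
    final double[] h, y;         // activations from the most recent forward pass

    TinyBp(int in, int hid, int out, long seed) {
        this.in = in; this.hid = hid; this.out = out;
        Random r = new Random(seed);
        wIJ = new double[in][hid];
        wJP = new double[hid][out];
        b1 = new double[hid];
        b2 = new double[out];
        // Small random initial weights in [-0.5, 0.5)
        for (double[] row : wIJ) for (int j = 0; j < hid; j++) row[j] = r.nextDouble() - 0.5;
        for (double[] row : wJP) for (int p = 0; p < out; p++) row[p] = r.nextDouble() - 0.5;
        h = new double[hid];
        y = new double[out];
    }

    static double sigmoid(double v) { return 1 / (1 + Math.exp(-v)); }

    /** Forward pass: weights stay fixed; each layer depends only on the previous one. */
    double[] forward(double[] x) {
        for (int j = 0; j < hid; j++) {
            double net = b1[j];
            for (int i = 0; i < in; i++) net += x[i] * wIJ[i][j];
            h[j] = sigmoid(net);
        }
        for (int p = 0; p < out; p++) {
            double net = b2[p];
            for (int j = 0; j < hid; j++) net += h[j] * wJP[j][p];
            y[p] = sigmoid(net);
        }
        return y.clone();
    }

    /** One forward + backward pass on one sample; returns 0.5 * squared error before the update. */
    double trainSample(double[] x, double[] t, double step) {
        forward(x);
        double err = 0;
        double[] dOut = new double[out];
        for (int p = 0; p < out; p++) {
            err += 0.5 * (t[p] - y[p]) * (t[p] - y[p]);
            dOut[p] = (t[p] - y[p]) * y[p] * (1 - y[p]);   // output-layer delta
        }
        double[] dHid = new double[hid];
        for (int j = 0; j < hid; j++) {
            double s = 0;
            for (int p = 0; p < out; p++) s += dOut[p] * wJP[j][p];
            dHid[j] = s * h[j] * (1 - h[j]);              // hidden-layer delta, uses old wJP
        }
        // Update from the output layer back toward the input layer
        for (int j = 0; j < hid; j++)
            for (int p = 0; p < out; p++) wJP[j][p] += step * dOut[p] * h[j];
        for (int p = 0; p < out; p++) b2[p] += step * dOut[p];
        for (int i = 0; i < in; i++)
            for (int j = 0; j < hid; j++) wIJ[i][j] += step * dHid[j] * x[i];
        for (int j = 0; j < hid; j++) b1[j] += step * dHid[j];
        return err;
    }

    /** Repeats epochs until the mean error is below precision or maxTimes is reached; returns epochs run. */
    int train(double[][] xs, double[][] ts, double step, double precision, int maxTimes) {
        int times = 0;
        while (times < maxTimes) {
            double err = 0;
            for (int n = 0; n < xs.length; n++) err += trainSample(xs[n], ts[n], step);
            times++;
            if (err / xs.length < precision) break;
        }
        return times;
    }
}
```

For example, `new TinyBp(1, 3, 1, 42L).train(xs, ts, 0.5, 1e-6, 5000)` fits a one-input network to the pairs in xs and ts and returns the number of epochs it actually ran.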

For the derivation, see https://zh.wikipedia.org/wiki/%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD%E7%AE%97%E6%B3%95
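For reference, the result of that derivation can be written out for the single-hidden-layer network used in this post. This is the standard gradient-descent formulation with a sigmoid activation f, a squared-error loss E and a momentum term. The symbols are assumptions chosen to echo the code's names:

- w_ij is weightIJ and w_jp is weightJP.
- b^(1) and b^(2) are b1 and b2, treated as additive biases.
- η is step, α is momentumFactor, t_p is the expected output and n is the iteration index.

The factory code may differ in details such as the error measure.

```latex
\begin{aligned}
h_j &= f\Big(\sum_i x_i w_{ij} + b^{(1)}_j\Big), \qquad
y_p = f\Big(\sum_j h_j w_{jp} + b^{(2)}_p\Big) \\
E &= \tfrac{1}{2}\sum_p (t_p - y_p)^2 \\
\delta_p &= (t_p - y_p)\, y_p (1 - y_p), \qquad
\delta_j = h_j (1 - h_j) \sum_p \delta_p w_{jp} \\
\Delta w_{jp}(n) &= \eta\, \delta_p h_j + \alpha\, \Delta w_{jp}(n-1), \qquad
\Delta b^{(2)}_p(n) = \eta\, \delta_p + \alpha\, \Delta b^{(2)}_p(n-1) \\
\Delta w_{ij}(n) &= \eta\, \delta_j x_i + \alpha\, \Delta w_{ij}(n-1), \qquad
\Delta b^{(1)}_j(n) = \eta\, \delta_j + \alpha\, \Delta b^{(1)}_j(n-1) \\
w &\leftarrow w + \Delta w(n), \qquad b \leftarrow b + \Delta b(n)
\end{aligned}
```

Here y_p(1 - y_p) and h_j(1 - h_j) are the sigmoid derivative expressed through the neuron's output. The hidden-layer delta δ_j uses the values of w_jp from before they are updated in the same iteration.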

BPNN structure

This BPNN has a single input layer, a single hidden layer and a single output layer.

Project structure

The classes used are:

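  • ActivationFunction: interface for activation functions
  • BPModel: entity class for a BP model
  • BPNeuralNetworkFactory: BP neural network factory; trains the network, computes outputs, serializes models, etc.
  • BPParameter: entity class holding the BP neural network parameters
  • Matrix: matrix entity class
  • Sigmoid: sigmoid transfer function, implements the ActivationFunction interface
  • MatrixUtil: matrix utility class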

Implementation

Matrix entity class

Implements the basic matrix operations.
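The Matrix class's sort method takes a com.top.constants.OrderEnum, which the post does not show. Judging from the ASC and DESC cases in sort, a minimal version could look like this. It is a hypothetical reconstruction. In the project it would be a public enum in package com.top.constants, and the package line is omitted here so the snippet compiles on its own.

```java
// Hypothetical reconstruction of com.top.constants.OrderEnum (not shown in the post);
// only the two constants used by Matrix.sort are included.
enum OrderEnum {
    ASC,  // ascending order
    DESC  // descending order
}
```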

package com.top.matrix;

import com.top.constants.OrderEnum;

import java.io.Serializable;

public class Matrix implements Serializable {
    private double[][] matrix;
    // Number of columns
    private int matrixColCount;
    // Number of rows
    private int matrixRowCount;

    /**
     * Creates an empty matrix
     */
    public Matrix() {
        this.matrix = null;
        this.matrixColCount = 0;
        this.matrixRowCount = 0;
    }

    /**
     * Creates a matrix backed by the given array (the array is not copied)
     * @param matrix
     */
    public Matrix(double[][] matrix) {
        this.matrix = matrix;
        this.matrixRowCount = matrix.length;
        this.matrixColCount = matrix[0].length;
    }

    /**
     * Creates a rowCount x colCount matrix filled with 0
     * @param rowCount
     * @param colCount
     */
    public Matrix(int rowCount,int colCount) {
        // Java zero-initializes new double arrays, so no explicit fill is needed
        this.matrix = new double[rowCount][colCount];
        this.matrixRowCount = rowCount;
        this.matrixColCount = colCount;
    }

    /**
     * Creates a rowCount x colCount matrix with every element set to val
     * @param val
     * @param rowCount
     * @param colCount
     */
    public Matrix(double val,int rowCount,int colCount) {
        double[][] matrix = new double[rowCount][colCount];
        for (int i = 0; i < rowCount; i++) {
            for (int j = 0; j < colCount; j++) {
                matrix[i][j] = val;
            }
        }
        this.matrix = matrix;
        this.matrixRowCount = rowCount;
        this.matrixColCount = colCount;
    }

    public double[][] getMatrix() {
        return matrix;
    }

    public void setMatrix(double[][] matrix) {
        this.matrix = matrix;
        this.matrixRowCount = matrix.length;
        this.matrixColCount = matrix[0].length;
    }

    public int getMatrixColCount() {
        return matrixColCount;
    }

    public int getMatrixRowCount() {
        return matrixRowCount;
    }

    /**
     * Returns the value at row x, column y
     *
     * @param x row index
     * @param y column index
     * @return
     */
    public double getValOfIdx(int x, int y) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        if (x < 0 || x > matrixRowCount - 1) {
            throw new IllegalArgumentException("Index x out of bounds");
        }
        if (y < 0 || y > matrixColCount - 1) {
            throw new IllegalArgumentException("Index y out of bounds");
        }
        return matrix[x][y];
    }

    /**
     * Returns row x as a 1 x colCount matrix
     *
     * @param x row index
     * @return
     */
    public Matrix getRowOfIdx(int x) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        if (x < 0 || x > matrixRowCount - 1) {
            throw new IllegalArgumentException("Index x out of bounds");
        }
        double[][] result = new double[1][];
        // Copy the row so that the result does not share storage with this matrix
        result[0] = matrix[x].clone();
        return new Matrix(result);
    }

    /**
     * Returns column y as a rowCount x 1 matrix
     *
     * @param y column index
     * @return
     */
    public Matrix getColOfIdx(int y) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        if (y < 0 || y > matrixColCount - 1) {
            throw new IllegalArgumentException("Index y out of bounds");
        }
        double[][] result = new double[matrixRowCount][1];
        for (int i = 0; i < matrixRowCount; i++) {
            result[i][0] = matrix[i][y];
        }
        return new Matrix(result);
    }

    /**
     * Sets the element at row x, column y to val
     * @param x row index
     * @param y column index
     * @param val
     */
    public void setValue(int x, int y, double val) {
        if (x < 0 || x > this.matrixRowCount - 1) {
            throw new IllegalArgumentException("Row index out of bounds");
        }
        if (y < 0 || y > this.matrixColCount - 1) {
            throw new IllegalArgumentException("Column index out of bounds");
        }
        this.matrix[x][y] = val;
    }

    /**
     * Matrix product this * a
     *
     * @param a
     * @return
     * @throws IllegalArgumentException
     */
    public Matrix multiple(Matrix a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        if (a.getMatrix() == null) {
            throw new IllegalArgumentException("Argument matrix is empty");
        }
        // The column count of this matrix must equal the row count of a
        if (matrixColCount != a.getMatrixRowCount()) {
            throw new IllegalArgumentException("Matrix dimensions do not match");
        }
        double[][] result = new double[matrixRowCount][a.getMatrixColCount()];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < a.getMatrixColCount(); j++) {
                for (int k = 0; k < matrixColCount; k++) {
                    result[i][j] = result[i][j] + matrix[i][k] * a.getMatrix()[k][j];
                }
            }
        }
        return new Matrix(result);
    }

    /**
     * Multiplies every element by a number
     *
     * @param a
     * @return
     */
    public Matrix multiple(double a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] * a;
            }
        }
        return new Matrix(result);
    }

    /**
     * Element-wise (Hadamard) product
     *
     * @param a
     * @return
     */
    public Matrix pointMultiple(Matrix a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        if (a.getMatrix() == null) {
            throw new IllegalArgumentException("Argument matrix is empty");
        }
        // Both the row count and the column count must match
        if (matrixRowCount != a.getMatrixRowCount() || matrixColCount != a.getMatrixColCount()) {
            throw new IllegalArgumentException("Matrix dimensions do not match");
        }
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] * a.getMatrix()[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Divides every element by a number
     * @param a
     * @return
     * @throws IllegalArgumentException
     */
    public Matrix divide(double a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] / a;
            }
        }
        return new Matrix(result);
    }

    /**
     * Matrix addition
     *
     * @param a
     * @return
     */
    public Matrix plus(Matrix a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        if (a.getMatrix() == null) {
            throw new IllegalArgumentException("Argument matrix is empty");
        }
        // Both the row count and the column count must match
        if (matrixRowCount != a.getMatrixRowCount() || matrixColCount != a.getMatrixColCount()) {
            throw new IllegalArgumentException("Matrix dimensions do not match");
        }
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] + a.getMatrix()[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Adds a number to every element
     * @param a
     * @return
     * @throws IllegalArgumentException
     */
    public Matrix plus(double a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] + a;
            }
        }
        return new Matrix(result);
    }

    /**
     * Matrix subtraction
     *
     * @param a
     * @return
     */
    public Matrix subtract(Matrix a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        if (a.getMatrix() == null) {
            throw new IllegalArgumentException("Argument matrix is empty");
        }
        // Both the row count and the column count must match
        if (matrixRowCount != a.getMatrixRowCount() || matrixColCount != a.getMatrixColCount()) {
            throw new IllegalArgumentException("Matrix dimensions do not match");
        }
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] - a.getMatrix()[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Subtracts a number from every element
     * @param a
     * @return
     * @throws IllegalArgumentException
     */
    public Matrix subtract(double a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] - a;
            }
        }
        return new Matrix(result);
    }

    /**
     * Sums each row (returns a rowCount x 1 matrix)
     *
     * @return
     */
    public Matrix sumRow() throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        double[][] result = new double[matrixRowCount][1];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][0] += matrix[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Sums each column (returns a 1 x colCount matrix)
     *
     * @return
     */
    public Matrix sumCol() throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        double[][] result = new double[1][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[0][j] += matrix[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Sums all elements
     *
     * @return
     */
    public double sumAll() throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        double result = 0;
        for (double[] doubles : matrix) {
            for (int j = 0; j < matrixColCount; j++) {
                result += doubles[j];
            }
        }
        return result;
    }

    /**
     * Squares every element
     *
     * @return
     */
    public Matrix square() throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] * matrix[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Raises every element to the power n
     *
     * @return
     */
    public Matrix pow(double n) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = Math.pow(matrix[i][j],n);
            }
        }
        return new Matrix(result);
    }

    /**
     * Transpose
     *
     * @return
     */
    public Matrix transpose() throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        double[][] result = new double[matrixColCount][matrixRowCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[j][i] = matrix[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Extracts a sub-matrix
     * @param startRowIndex index of the first row
     * @param rowCount   number of rows to take
     * @param startColIndex index of the first column
     * @param colCount   number of columns to take
     * @return
     * @throws IllegalArgumentException
     */
    public Matrix subMatrix(int startRowIndex,int rowCount,int startColIndex,int colCount) throws IllegalArgumentException {
        if (startRowIndex < 0 || rowCount < 0 || startRowIndex + rowCount > matrixRowCount) {
            throw new IllegalArgumentException("Row index out of bounds");
        }
        if (startColIndex < 0 || colCount < 0 || startColIndex + colCount > matrixColCount) {
            throw new IllegalArgumentException("Column index out of bounds");
        }
        double[][] result = new double[rowCount][colCount];
        for (int i = startRowIndex; i < startRowIndex + rowCount; i++) {
            // Copy colCount values of row i, starting at column startColIndex
            System.arraycopy(matrix[i], startColIndex, result[i - startRowIndex], 0, colCount);
        }
        return new Matrix(result);
    }

    /**
     * Concatenates this matrix with a
     * @param direction direction of concatenation: 1 = horizontal, 2 = vertical
     * @param a
     * @return
     * @throws IllegalArgumentException
     */
    public Matrix splice(int direction, Matrix a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        if (a.getMatrix() == null) {
            throw new IllegalArgumentException("Argument matrix is empty");
        }
        if(direction == 1){
            // Horizontal concatenation
            if (matrixRowCount != a.getMatrixRowCount()) {
                throw new IllegalArgumentException("Row counts differ; cannot concatenate horizontally");
            }
            double[][] result = new double[matrixRowCount][matrixColCount + a.getMatrixColCount()];
            for (int i = 0; i < matrixRowCount; i++) {
                System.arraycopy(matrix[i],0,result[i],0,matrixColCount);
                System.arraycopy(a.getMatrix()[i],0,result[i],matrixColCount,a.getMatrixColCount());
            }
            return new Matrix(result);
        }else if(direction == 2){
            // Vertical concatenation; rows are cloned so the result does not share storage with the inputs
            if (matrixColCount != a.getMatrixColCount()) {
                throw new IllegalArgumentException("Column counts differ; cannot concatenate vertically");
            }
            double[][] result = new double[matrixRowCount + a.getMatrixRowCount()][];
            for (int i = 0; i < matrixRowCount; i++) {
                result[i] = matrix[i].clone();
            }
            for (int i = 0; i < a.getMatrixRowCount(); i++) {
                result[matrixRowCount + i] = a.getMatrix()[i].clone();
            }
            return new Matrix(result);
        }else{
            throw new IllegalArgumentException("Invalid direction argument");
        }
    }

    /**
     * Extends the matrix by repeating it
     * @param direction direction of repetition: 1 = horizontal, 2 = vertical
     * @param a number of copies
     * @return
     * @throws IllegalArgumentException
     */
    public Matrix extend(int direction , int a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("Matrix is empty");
        }
        if(direction == 1){
            // Repeat horizontally
            double[][] result = new double[matrixRowCount][matrixColCount*a];
            for (int i = 0; i < matrixRowCount; i++) {
                for (int j = 0; j < a; j++) {
                    System.arraycopy(matrix[i],0,result[i],j*matrixColCount,matrixColCount);
                }
            }
            return new Matrix(result);
        }else if(direction == 2){
            // Repeat vertically; clone() gives every copy its own row array instead of sharing one
            double[][] result = new double[matrixRowCount*a][];
            for (int i = 0; i < matrixRowCount*a; i++) {
                result[i] = matrix[i%matrixRowCount].clone();
            }
            return new Matrix(result);
        }else{
            throw new IllegalArgumentException("Invalid direction argument");
        }
    }
    /**
     * Returns the mean of each column (a 1 x colCount matrix)
     * @return
     * @throws IllegalArgumentException
     */
    public Matrix getColAvg() throws IllegalArgumentException {
        Matrix tmp = this.sumCol();
        return tmp.divide(matrixRowCount);
    }

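    /**
     * Sorts the rows of the matrix in place (bubble sort) by the values in one column
     * @param index index of the column whose values determine the row order
     * @param order sort order, OrderEnum.ASC (ascending) or OrderEnum.DESC (descending)
     * @throws IllegalArgumentException
     */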
    public void sort(int index,OrderEnum order) throws IllegalArgumentException{
        switch (order){
            case ASC:
                for (int i = 0; i < this.matrixRowCount; i++) {
                    for (int j = 0; j < this.matrixRowCount - 1 - i; j++) {
                        if (this.matrix[j][index] > this.matrix[j + 1][index]) {
                            double[] tmp = this.matrix[j];
                            this.matrix[j] = this.matrix[j + 1];
                            this.matrix[j + 1] = tmp;
                        }
                    }
                }
                break;
            case DESC:
                for (int i = 0; i < this.matrixRowCount; i++) {
                    for (int j = 0; j < this.matrixRowCount - 1 - i; j++) {
                        if (this.matrix[j][index] < this.matrix[j + 1][index]) {
                            double[] tmp = this.matrix[j];
                            this.matrix[j] = this.matrix[j + 1];
                            this.matrix[j + 1] = tmp;
                        }
                    }
                }
                break;
            default:

        }

    }

    /**
     * Checks whether this is a square matrix:
     * the row and column counts are equal and not 0
     * @return
     */
    public boolean isSquareMatrix(){
        return matrixColCount == matrixRowCount && matrixColCount != 0;
    }

    @Override
    public String toString() {
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append("\r\n");
        for (int i = 0; i < matrixRowCount; i++) {
            stringBuilder.append("# ");
            for (int j = 0; j < matrixColCount; j++) {
                stringBuilder.append(matrix[i][j]).append("\t ");
            }
            stringBuilder.append("#\r\n");
        }
        stringBuilder.append("\r\n");
        return stringBuilder.toString();
    }
}
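Several Matrix methods build their result by assigning row references rather than copying values. Examples are getRowOfIdx and the vertical branches of splice and extend. In Java such an assignment copies the reference, so the result and the source share the same row arrays. A later setValue or in-place sort on one of them then silently affects the other.

The self-contained demo below contrasts sharing rows with cloning them. RowAliasDemo is a hypothetical name, and the class is not part of the project.

```java
import java.util.Arrays;

// Hypothetical demo: repeating rows by reference vs. by copy.
class RowAliasDemo {
    /** Vertically repeats the rows of m `times` times; shareRows = true assigns references, false clones. */
    static double[][] repeatRows(double[][] m, int times, boolean shareRows) {
        double[][] result = new double[m.length * times][];
        for (int i = 0; i < result.length; i++) {
            double[] src = m[i % m.length];
            result[i] = shareRows ? src : src.clone();
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] shared = repeatRows(new double[][]{{1, 2}}, 2, true);
        shared[0][0] = 99;                                  // also changes shared[1][0]
        System.out.println(Arrays.deepToString(shared));    // [[99.0, 2.0], [99.0, 2.0]]

        double[][] copied = repeatRows(new double[][]{{1, 2}}, 2, false);
        copied[0][0] = 99;                                  // only the first row changes
        System.out.println(Arrays.deepToString(copied));    // [[99.0, 2.0], [1.0, 2.0]]
    }
}
```

Because `clone()` on a `double[]` copies the primitive values, one level of cloning is enough for a two-dimensional `double[][]`.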

MatrixUtil utility class

package com.top.utils;

import com.top.matrix.Matrix;

import java.util.*;

public class MatrixUtil {
    /**
     * Creates an identity matrix
     * @param matrixRowCount dimension (number of rows and columns) of the identity matrix
     * @return
     */
    public static Matrix eye(int matrixRowCount){
        double[][] result = new double[matrixRowCount][matrixRowCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixRowCount; j++) {
                if(i == j){
                    result[i][j] = 1;
                }else{
                    result[i][j] = 0;
                }
            }
        }
        return new Matrix(result);
    }

    /**
     * Computes the inverse of a matrix by Gauss-Jordan elimination:
     * the augmented matrix [A | E] is row-reduced to [E | A^-1]
     * @param a
     * @return
     * @throws Exception
     */
    public static Matrix inv(Matrix a) throws Exception {
        if (!invable(a)) {
            throw new Exception("Matrix is not invertible");
        }
        // [a|E]
        Matrix b = a.splice(1, eye(a.getMatrixRowCount()));
        double[][] data = b.getMatrix();
        int rowCount = b.getMatrixRowCount();
        int colCount = b.getMatrixColCount();
        // a is square, so its column count equals b's row count
        for (int j = 0; j < rowCount; j++) {
            // If the pivot is 0, swap in a lower row that has a nonzero entry in column j
            int notZeroRow = -2;
            if(data[j][j] == 0){
                notZeroRow = -1;
                for (int l = j; l < rowCount; l++) {
                    if (data[l][j] != 0) {
                        notZeroRow = l;
                        break;
                    }
                }
            }
            if (notZeroRow == -1) {
                throw new Exception("Matrix is not invertible");
            }else if(notZeroRow != -2){
                // Swap rows j and notZeroRow
                double[] tmp = data[j];
                data[j] = data[notZeroRow];
                data[notZeroRow] = tmp;
            }
            // Scale row j so that the pivot data[j][j] becomes 1
            if (data[j][j] != 1) {
                double multiple = data[j][j];
                for (int colIdx = j; colIdx < colCount; colIdx++) {
                    data[j][colIdx] /= multiple;
                }
            }
            // Eliminate column j from every other row
            for (int i = 0; i < rowCount; i++) {
                if (i != j) {
                    double multiple = data[i][j] / data[j][j];
                    // Iterate over the columns of the row
                    for (int k = j; k < colCount; k++) {
                        data[i][k] = data[i][k] - multiple * data[j][k];
                    }
                }
            }
        }
        Matrix result = new Matrix(data);
        return result.subMatrix(0, rowCount, rowCount, rowCount);
    }

    /**
     * Computes the adjugate (classical adjoint) matrix
     * using adj(A) = |A| A^-1
     * @param a
     * @return
     * @throws Exception
     */
    public static Matrix adj(Matrix a) throws Exception {
        return inv(a).multiple(det(a));
    }

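    /**
     * Converts a square matrix to upper-triangular form by Gaussian elimination.
     * The argument is copied first, so it is not modified. When a row swap is needed,
     * the swapped-in row is negated so that the determinant stays unchanged.
     * @param a square matrix
     * @return upper-triangular matrix with the same determinant as a
     * @throws Exception if a is not square, or if no nonzero pivot exists in some column (singular matrix)
     */
    public static Matrix getTopTriangle(Matrix a) throws Exception {
        if (!a.isSquareMatrix()) {
            throw new Exception("Not a square matrix; cannot compute");
        }
        int matrixHeight = a.getMatrixRowCount();
        // Deep copy: invable() and det() must not overwrite the caller's matrix
        double[][] result = new double[matrixHeight][];
        for (int i = 0; i < matrixHeight; i++) {
            result[i] = a.getMatrix()[i].clone();
        }
        // Iterate over columns
        for (int j = 0; j < matrixHeight; j++) {
            // If the pivot is 0, look for a lower row with a nonzero entry in column j
            if (result[j][j] == 0) {
                int notZeroRow = -1;
                for (int l = j + 1; l < matrixHeight; l++) {
                    if (result[l][j] != 0) {
                        notZeroRow = l;
                        break;
                    }
                }
                if (notZeroRow == -1) {
                    throw new Exception("Matrix is singular (not invertible)");
                }
                // Swap rows j and notZeroRow, then negate the new row j to keep the determinant unchanged
                double[] tmp = result[j];
                result[j] = result[notZeroRow];
                result[notZeroRow] = tmp;
                for (int k = 0; k < matrixHeight; k++) {
                    result[j][k] = -result[j][k];
                }
            }
            // Eliminate the entries below the pivot
            for (int i = j + 1; i < matrixHeight; i++) {
                double multiple = result[i][j] / result[j][j];
                for (int k = j; k < matrixHeight; k++) {
                    result[i][k] = result[i][k] - multiple * result[j][k];
                }
            }
        }
        return new Matrix(result);
    }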

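    /**
     * Computes the determinant as the product of the diagonal of the upper-triangular form
     * @param a square matrix
     * @return
     * @throws Exception if a is not square, or if getTopTriangle finds no nonzero pivot (singular matrix)
     */
    public static double det(Matrix a) throws Exception {
        // Convert the matrix to upper-triangular form
        Matrix b = MatrixUtil.getTopTriangle(a);
        double result = 1;
        // The determinant of a triangular matrix is the product of its diagonal entries
        for (int i = 0; i < b.getMatrixRowCount(); i++) {
            result *= b.getValOfIdx(i, i);
        }
        return result;
    }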
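
    /**
     * Computes the sample covariance matrix of the columns of a
     * (each row is one observation; the sum is divided by n - 1)
     * @param a
     * @return
     * @throws Exception
     */
    public static Matrix cov(Matrix a) throws Exception {
        if (a.getMatrix() == null) {
            throw new Exception("Matrix is empty");
        }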
        Matrix avg = a.getColAvg().extend(2, a.getMatrixRowCount());
        Matrix tmp = a.subtract(avg);
        return tmp.transpose().multiple(tmp).multiple(1/((double) a.getMatrixRowCount() -1));
    }

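    /**
     * Checks whether a matrix is invertible:
     * it is treated as invertible if getTopTriangle can reduce it to
     * upper-triangular form without running out of nonzero pivots
     * @param a
     * @return
     */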
    public static boolean invable(Matrix a) {
        try {
            getTopTriangle(a);
            return true;
        } catch (Exception e) {
            return false;
        }
    }


    /**
     * Min-max normalization, applied column by column
     * @param a data to normalize
     * @param normalizationMin  lower bound of the target interval
     * @param normalizationMax  upper bound of the target interval
     * @return map with "res" (normalized data) and "max"/"min" (1 x colCount matrices of per-column max and min)
     */
    public static Map<String, Object> normalize(Matrix a, double normalizationMin, double normalizationMax) throws Exception {
        HashMap<String, Object> result = new HashMap<>();
        double[][] maxArr = new double[1][a.getMatrixColCount()];
        double[][] minArr = new double[1][a.getMatrixColCount()];
        double[][] res = new double[a.getMatrixRowCount()][a.getMatrixColCount()];
        for (int i = 0; i < a.getMatrixColCount(); i++) {
            // Find the max and min of column i
            double max = Double.NEGATIVE_INFINITY;
            double min = Double.POSITIVE_INFINITY;
            for (int j = 0; j < a.getMatrixRowCount(); j++) {
                max = Math.max(max, a.getValOfIdx(j, i));
                min = Math.min(min, a.getValOfIdx(j, i));
            }
            for (int j = 0; j < a.getMatrixRowCount(); j++) {
                if (max != min) {
                    res[j][i] = normalizationMin + (a.getValOfIdx(j,i) - min) / (max - min) * (normalizationMax - normalizationMin);
                } else {
                    // Constant column: max - min would be 0, so map every value to the lower bound
                    res[j][i] = normalizationMin;
                }
            }
            maxArr[0][i] = max;
            minArr[0][i] = min;
        }
        result.put("max", new Matrix(maxArr));
        result.put("min", new Matrix(minArr));
        result.put("res", new Matrix(res));
        return result;
    }

    /**
     * Inverse normalization, applied column by column.
     * Note the parameter order: the interval's upper bound comes before its lower bound.
     * @param a data to map back to the original scale
     * @param normalizationMax upper bound of the normalization interval
     * @param normalizationMin lower bound of the normalization interval
     * @param dataMax   per-column maximum of the original data (1 x colCount)
     * @param dataMin   per-column minimum of the original data (1 x colCount)
     * @return
     */
    public static Matrix inverseNormalize(Matrix a, double normalizationMax, double normalizationMin , Matrix dataMax,Matrix dataMin){
        double[][] res = new double[a.getMatrixRowCount()][a.getMatrixColCount()];
        for (int i = 0; i < a.getMatrixColCount(); i++) {
            double max = dataMax.getValOfIdx(0, i);
            double min = dataMin.getValOfIdx(0, i);
            // Inverse of the min-max mapping; a constant column (max == min) maps back to min
            for (int j = 0; j < a.getMatrixRowCount(); j++) {
                res[j][i] = min + (max - min) * (a.getValOfIdx(j,i) - normalizationMin) / (normalizationMax - normalizationMin);
            }
        }
        return new Matrix(res);
    }
}
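det multiplies the diagonal of the upper-triangular form, so any row swap made during elimination must be accounted for, because a swap negates the determinant. For example, det([[0, 1], [1, 0]]) = -1.

The standalone sketch below (DetSketch is a hypothetical helper, not part of MatrixUtil) flips the sign on each swap. It also uses partial pivoting, which chooses the largest available pivot and is a common way to reduce rounding error.

```java
// Hypothetical helper: determinant via Gaussian elimination with partial pivoting.
final class DetSketch {
    static double det(double[][] input) {
        int n = input.length;
        double[][] a = new double[n][];
        for (int i = 0; i < n; i++) {
            if (input[i].length != n) throw new IllegalArgumentException("matrix is not square");
            a[i] = input[i].clone();                 // work on a copy; the input is untouched
        }
        double det = 1;
        for (int j = 0; j < n; j++) {
            // Partial pivoting: pick the row with the largest |a[r][j]| at or below row j
            int pivot = j;
            for (int r = j + 1; r < n; r++) {
                if (Math.abs(a[r][j]) > Math.abs(a[pivot][j])) pivot = r;
            }
            if (a[pivot][j] == 0) return 0;          // no nonzero pivot: the matrix is singular
            if (pivot != j) {
                double[] tmp = a[j]; a[j] = a[pivot]; a[pivot] = tmp;
                det = -det;                           // each row swap negates the determinant
            }
            det *= a[j][j];
            // Eliminate the entries below the pivot
            for (int r = j + 1; r < n; r++) {
                double f = a[r][j] / a[j][j];
                for (int k = j; k < n; k++) a[r][k] -= f * a[j][k];
            }
        }
        return det;
    }
}
```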

ActivationFunction interface

package com.top.bpnn;

public interface ActivationFunction {
    // Value of the function at val
    double computeValue(double val);
    // Derivative of the function at val
    double computeDerivative(double val);
}
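Because BPParameter accepts any ActivationFunction, a different transfer function can be plugged in by implementing the interface. The sketch below repeats the interface so that it compiles on its own. It then adds a hypothetical Tanh class, which is not part of the original project. Tanh outputs lie in (-1, 1), so the default normalization interval [0.2, 0.8] still falls inside its range.

```java
// Copy of the post's ActivationFunction interface, repeated so this example compiles alone.
interface ActivationFunction {
    double computeValue(double val);
    double computeDerivative(double val);
}

// Hypothetical alternative activation (not in the original project): hyperbolic tangent.
class Tanh implements ActivationFunction, java.io.Serializable {
    @Override
    public double computeValue(double val) {
        return Math.tanh(val);
    }

    @Override
    public double computeDerivative(double val) {
        double t = Math.tanh(val);
        return 1 - t * t;  // d/dx tanh(x) = 1 - tanh(x)^2
    }
}
```

It would be selected with `bpParameter.setActivationFunction(new Tanh())`, assuming the class is placed in the project's package.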

Sigmoid

package com.top.bpnn;

import java.io.Serializable;

public class Sigmoid implements ActivationFunction, Serializable {
    @Override
    public double computeValue(double val) {
        return 1 / (1 + Math.exp(-val));
    }

    // f'(x) = f(x) * (1 - f(x)); f(x) is evaluated once and reused
    @Override
    public double computeDerivative(double val) {
        double s = computeValue(val);
        return s * (1 - s);
    }
}
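Sigmoid.computeDerivative expects the pre-activation input val, not the neuron's output. If a caller already holds the output s = f(val), the derivative is simply s(1 - s). Passing s to computeDerivative would apply the sigmoid twice. Whether the factory passes net inputs or outputs is not visible in this part of the post.

The hypothetical SigmoidGradCheck class below checks the analytic formula against a central finite difference.

```java
// Hypothetical check (not part of the project): analytic sigmoid derivative vs. finite difference.
final class SigmoidGradCheck {
    static double sigmoid(double x) {
        return 1 / (1 + Math.exp(-x));
    }

    /** f'(x) = f(x)(1 - f(x)), the formula used by Sigmoid.computeDerivative. */
    static double analytic(double x) {
        double s = sigmoid(x);
        return s * (1 - s);
    }

    /** Central difference (f(x + h) - f(x - h)) / 2h. */
    static double numeric(double x, double h) {
        return (sigmoid(x + h) - sigmoid(x - h)) / (2 * h);
    }

    /** Largest absolute difference between the two on a grid over [-6, 6]. */
    static double maxError() {
        double worst = 0;
        for (double x = -6; x <= 6; x += 0.5) {
            worst = Math.max(worst, Math.abs(analytic(x) - numeric(x, 1e-5)));
        }
        return worst;
    }
}
```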

BPParameter

Holds the parameters needed to train the BP neural network.

package com.top.bpnn;

import java.io.Serializable;

public class BPParameter implements Serializable {

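    // Number of input-layer neurons
    private int inputLayerNeuronCount = 3;
    // Number of hidden-layer neurons
    private int hiddenLayerNeuronCount = 3;
    // Number of output-layer neurons
    private int outputLayerNeuronCount = 1;
    // Normalization interval
    private double normalizationMin = 0.2;
    private double normalizationMax = 0.8;
    // Learning rate (step size)
    private double step = 0.05;
    // Momentum factor
    private double momentumFactor = 0.2;
    // Activation function
    private ActivationFunction activationFunction = new Sigmoid();
    // Target precision: training stops once the error falls below this value
    private double precision = 0.000001;
    // Maximum number of training iterations
    private int maxTimes = 1000000;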

    public double getMomentumFactor() {
        return momentumFactor;
    }

    public void setMomentumFactor(double momentumFactor) {
        this.momentumFactor = momentumFactor;
    }

    public double getStep() {
        return step;
    }

    public void setStep(double step) {
        this.step = step;
    }

    public double getNormalizationMin() {
        return normalizationMin;
    }

    public void setNormalizationMin(double normalizationMin) {
        this.normalizationMin = normalizationMin;
    }

    public double getNormalizationMax() {
        return normalizationMax;
    }

    public void setNormalizationMax(double normalizationMax) {
        this.normalizationMax = normalizationMax;
    }

    public int getInputLayerNeuronCount() {
        return inputLayerNeuronCount;
    }

    public void setInputLayerNeuronCount(int inputLayerNeuronCount) {
        this.inputLayerNeuronCount = inputLayerNeuronCount;
    }

    public int getHiddenLayerNeuronCount() {
        return hiddenLayerNeuronCount;
    }

    public void setHiddenLayerNeuronCount(int hiddenLayerNeuronCount) {
        this.hiddenLayerNeuronCount = hiddenLayerNeuronCount;
    }

    public int getOutputLayerNeuronCount() {
        return outputLayerNeuronCount;
    }

    public void setOutputLayerNeuronCount(int outputLayerNeuronCount) {
        this.outputLayerNeuronCount = outputLayerNeuronCount;
    }

    public ActivationFunction getActivationFunction() {
        return activationFunction;
    }

    public void setActivationFunction(ActivationFunction activationFunction) {
        this.activationFunction = activationFunction;
    }

    public double getPrecision() {
        return precision;
    }

    public void setPrecision(double precision) {
        this.precision = precision;
    }

    public int getMaxTimes() {
        return maxTimes;
    }

    public void setMaxTimes(int maxTimes) {
        this.maxTimes = maxTimes;
    }
}
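With the defaults normalizationMin = 0.2 and normalizationMax = 0.8, every data column is mapped into [0.2, 0.8]. A common reason for such an interval is that it keeps the targets away from 0 and 1, where the sigmoid saturates and its derivative approaches zero.

The hypothetical MinMaxSketch class below shows the forward and inverse mapping for one column of plain doubles. As an assumption of this sketch, a constant column is mapped to the lower bound to avoid dividing by zero.

```java
import java.util.Arrays;

// Hypothetical helper (not in the post): per-column min-max scaling into [lo, hi] and back.
final class MinMaxSketch {
    static double min(double[] col) {
        return Arrays.stream(col).min().getAsDouble();
    }

    static double max(double[] col) {
        return Arrays.stream(col).max().getAsDouble();
    }

    /** Maps col from [min, max] into [lo, hi]; a constant column (max == min) is mapped to lo. */
    static double[] normalize(double[] col, double lo, double hi, double min, double max) {
        double[] out = new double[col.length];
        for (int i = 0; i < col.length; i++) {
            out[i] = (max == min) ? lo : lo + (col[i] - min) / (max - min) * (hi - lo);
        }
        return out;
    }

    /** Inverse mapping from [lo, hi] back to [min, max]; yields min when max == min. */
    static double[] inverse(double[] col, double lo, double hi, double min, double max) {
        double[] out = new double[col.length];
        for (int i = 0; i < col.length; i++) {
            out[i] = min + (max - min) * (col[i] - lo) / (hi - lo);
        }
        return out;
    }
}
```

The min and max of the training data have to be stored with the model, as BPModel does with its inputMax, inputMin, outputMax and outputMin fields. Predictions can then be mapped back to the original scale.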

BPModel

The BP neural network model: its weights and thresholds, training parameters and related properties.

package com.top.bpnn;

import com.top.matrix.Matrix;

import java.io.Serializable;

public class BPModel implements Serializable {
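    // Weights and thresholds of the BP neural network
    private Matrix weightIJ;   // input -> hidden weights
    private Matrix b1;         // hidden-layer thresholds
    private Matrix weightJP;   // hidden -> output weights
    private Matrix b2;         // output-layer thresholds
    /* Per-column max/min of the training data, used for (inverse) normalization */
    private Matrix inputMax;
    private Matrix inputMin;
    private Matrix outputMax;
    private Matrix outputMin;
    /* Training parameters of the BP neural network */
    private BPParameter bpParameter;
    /* Training outcome */
    private double error;   // error reached
    private int times;      // number of iterations run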

    public Matrix getWeightIJ() {
        return weightIJ;
    }

    public void setWeightIJ(Matrix weightIJ) {
        this.weightIJ = weightIJ;
    }

    public Matrix getB1() {
        return b1;
    }

    public void setB1(Matrix b1) {
        this.b1 = b1;
    }

    public Matrix getWeightJP() {
        return weightJP;
    }

    public void setWeightJP(Matrix weightJP) {
        this.weightJP = weightJP;
    }

    public Matrix getB2() {
        return b2;
    }

    public void setB2(Matrix b2) {
        this.b2 = b2;
    }

    public Matrix getInputMax() {
        return inputMax;
    }

    public void setInputMax(Matrix inputMax) {
        this.inputMax = inputMax;
    }

    public Matrix getInputMin() {
        return inputMin;
    }

    public void setInputMin(Matrix inputMin) {
        this.inputMin = inputMin;
    }

    public Matrix getOutputMax() {
        return outputMax;
    }

    public void setOutputMax(Matrix outputMax) {
        this.outputMax = outputMax;
    }

    public Matrix getOutputMin() {
        return outputMin;
    }

    public void setOutputMin(Matrix outputMin) {
        this.outputMin = outputMin;
    }

    public BPParameter getBpParameter() {
        return bpParameter;
    }

    public void setBpParameter(BPParameter bpParameter) {
        this.bpParameter = bpParameter;
    }

    public double getError() {
        return error;
    }

    public void setError(double error) {
        this.error = error;
    }

    public int getTimes() {
        return times;
    }

    public void setTimes(int times) {
        this.times = times;
    }
}
BPModel code
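The inputMax/inputMin and outputMax/outputMin matrices are what make a saved model usable on raw data: computeBP maps each input column into [normalizationMin, normalizationMax] with min-max scaling, and the network output is mapped back to the original output range. The standalone sketch below shows both formulas for a single value. normalize mirrors the formula in computeBP; denormalize is simply its inverse and is only an assumption about what MatrixUtil.inverseNormalize does, since its body is not shown here. The class name and the constant-column guard are additions for this illustration and are not in the original code.

```java
import java.util.Arrays;

public class MinMaxNormalizationSketch {

    /** Maps x from [min, max] into [lo, hi]; the same formula computeBP applies to each input column. */
    static double normalize(double x, double min, double max, double lo, double hi) {
        // Added guard: a constant column (max == min) would otherwise divide by zero
        if (max == min) {
            return lo;
        }
        return lo + (x - min) / (max - min) * (hi - lo);
    }

    /** Maps y from [lo, hi] back into [min, max]; the inverse of normalize. */
    static double denormalize(double y, double min, double max, double lo, double hi) {
        return min + (y - lo) / (hi - lo) * (max - min);
    }

    public static void main(String[] args) {
        double[] raw = {12.0, 30.0, 21.0};
        // During training these would be the per-column min/max stored in BPModel
        double min = Arrays.stream(raw).min().getAsDouble();
        double max = Arrays.stream(raw).max().getAsDouble();
        for (double x : raw) {
            double n = normalize(x, min, max, 0.0, 1.0);
            System.out.println(x + " -> " + n + " -> " + denormalize(n, min, max, 0.0, 1.0));
        }
    }
}
```

Because the min/max are taken from the training data, inputs outside that range are mapped outside [normalizationMin, normalizationMax], where a sigmoid-based network tends to extrapolate poorly.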

BPNeuralNetworkFactory

The BP neural network factory, which provides training of the BP neural network and related functions.

package com.top.bpnn;

import com.top.matrix.Matrix;
import com.top.utils.MatrixUtil;

import java.util.*;

public class BPNeuralNetworkFactory {
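    /**
     * Trains a BP neural network model
     * @param bpParameter training parameters
     * @param inputAndOutput training data; each row holds the input values followed by the expected output values
     * @return the trained model
     */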
    public BPModel trainBP(BPParameter bpParameter, Matrix inputAndOutput) throws Exception {

        ActivationFunction activationFunction = bpParameter.getActivationFunction();
        int inputCount = bpParameter.getInputLayerNeuronCount();
        int hiddenCount = bpParameter.getHiddenLayerNeuronCount();
        int outputCount = bpParameter.getOutputLayerNeuronCount();
        double normalizationMin = bpParameter.getNormalizationMin();
        double normalizationMax = bpParameter.getNormalizationMax();
        double step = bpParameter.getStep();
        double momentumFactor = bpParameter.getMomentumFactor();
        double precision = bpParameter.getPrecision();
        int maxTimes = bpParameter.getMaxTimes();

        if(inputAndOutput.getMatrixColCount() != inputCount + outputCount){
            throw new Exception("Column count of the training data does not match the input + output neuron counts");
        }
        // Initialize the weights
        Matrix weightIJ = initWeight(inputCount, hiddenCount);
        Matrix weightJP = initWeight(hiddenCount, outputCount);

        // Initialize the thresholds
        Matrix b1 = initThreshold(hiddenCount);
        Matrix b2 = initThreshold(outputCount);

        // Previous corrections, used for the momentum term
        Matrix deltaWeightIJ0 = new Matrix(inputCount, hiddenCount);
        Matrix deltaWeightJP0 = new Matrix(hiddenCount, outputCount);
        Matrix deltaB10 = new Matrix(1, hiddenCount);
        Matrix deltaB20 = new Matrix(1, outputCount);

        // Split the data into the input matrix and the output matrix
        Matrix input = inputAndOutput.subMatrix(0,inputAndOutput.getMatrixRowCount(),0,inputCount);
        Matrix output = inputAndOutput.subMatrix(0,inputAndOutput.getMatrixRowCount(),inputCount,outputCount);

        // Normalize
        Map<String,Object> inputAfterNormalize = MatrixUtil.normalize(input, normalizationMin, normalizationMax);
        input = (Matrix) inputAfterNormalize.get("res");

        Map<String,Object> outputAfterNormalize = MatrixUtil.normalize(output, normalizationMin, normalizationMax);
        output = (Matrix) outputAfterNormalize.get("res");

        int times = 1;
        double E = 0; // error
        while (times < maxTimes) {
            /*-----------------Forward propagation---------------------*/
            // Hidden-layer input
            Matrix jIn = input.multiple(weightIJ);
            // Expand the thresholds to one row per sample
            Matrix b1Copy = b1.extend(2,jIn.getMatrixRowCount());
            // Add the thresholds
            jIn = jIn.plus(b1Copy);
            // Hidden-layer output
            Matrix jOut = computeValue(jIn,activationFunction);
            // Output-layer input
            Matrix pIn = jOut.multiple(weightJP);
            // Expand the thresholds to one row per sample
            Matrix b2Copy = b2.extend(2, pIn.getMatrixRowCount());
            // Add the thresholds
            pIn = pIn.plus(b2Copy);
            // Output-layer output
            Matrix pOut = computeValue(pIn,activationFunction);
            // Compute the error
            Matrix e = output.subtract(pOut);
            E = computeE(e); // E = 0.5 * sum of squared errors
            // Stop once the required precision is reached
            if (Math.abs(E) <= precision) {
                System.out.println("Precision reached");
                break;
            }

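            /*-----------------Backward propagation---------------------*/
            // Local gradient of the output layer: element-wise product of e and f'(pIn), an n×P matrix
            Matrix deltaO = e.pointMultiple(computeDerivative(pIn, activationFunction));
            // Correction for the weights between J and P: step * jOut^T * deltaO (J×P)
            Matrix deltaWeightJP = jOut.transpose().multiple(deltaO).multiple(step);
            // Row vector of ones (1×n); left-multiplying by it sums each column over all samples
            double[][] onesData = new double[1][input.getMatrixRowCount()];
            Arrays.fill(onesData[0], 1.0);
            Matrix ones = new Matrix(onesData);
            // Correction for the P-layer thresholds: step * column sums of deltaO (1×P, same shape as b2)
            Matrix deltaThresholdP = ones.multiple(deltaO).multiple(step);

            // Error propagated back to the hidden layer: (weightJP * deltaO^T)^T, an n×J matrix
            Matrix tmp = weightJP.multiple(deltaO.transpose()).transpose();
            // Local gradient of the hidden layer: element-wise product of tmp and f'(jIn), an n×J matrix
            Matrix deltaJ = tmp.pointMultiple(computeDerivative(jIn, activationFunction));
            // Correction for the weights between I and J: step * input^T * deltaJ (I×J)
            Matrix deltaWeightIJ = input.transpose().multiple(deltaJ).multiple(step);

            // Correction for the J-layer thresholds: step * column sums of deltaJ (1×J, same shape as b1).
            // The thresholds are added to the net input, so they are corrected with +step, like the weights.
            Matrix deltaThresholdJ = ones.multiple(deltaJ).multiple(step);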

            if (times == 1) {
                // Update the weights and thresholds
                weightIJ = weightIJ.plus(deltaWeightIJ);
                weightJP = weightJP.plus(deltaWeightJP);
                b1 = b1.plus(deltaThresholdJ);
                b2 = b2.plus(deltaThresholdP);
            }else{
                // Add the momentum term: momentumFactor times the previous correction
                weightIJ = weightIJ.plus(deltaWeightIJ).plus(deltaWeightIJ0.multiple(momentumFactor));
                weightJP = weightJP.plus(deltaWeightJP).plus(deltaWeightJP0.multiple(momentumFactor));
                b1 = b1.plus(deltaThresholdJ).plus(deltaB10.multiple(momentumFactor));
                b2 = b2.plus(deltaThresholdP).plus(deltaB20.multiple(momentumFactor));
            }

            deltaWeightIJ0 = deltaWeightIJ;
            deltaWeightJP0 = deltaWeightJP;
            deltaB10 = deltaThresholdJ;
            deltaB20 = deltaThresholdP;

            times++;
        }

        // Assemble the trained BP neural network model
        BPModel result = new BPModel();
        result.setInputMax((Matrix) inputAfterNormalize.get("max"));
        result.setInputMin((Matrix) inputAfterNormalize.get("min"));
        result.setOutputMax((Matrix) outputAfterNormalize.get("max"));
        result.setOutputMin((Matrix) outputAfterNormalize.get("min"));
        result.setWeightIJ(weightIJ);
        result.setWeightJP(weightJP);
        result.setB1(b1);
        result.setB2(b2);
        result.setError(E);
        result.setTimes(times);
        result.setBpParameter(bpParameter);
        System.out.println("Iterations: " + times + ", error: " + E);

        return result;
    }

    /**
     * Computes the output of a trained BP neural network
     * @param bpModel trained model
     * @param input input matrix, one sample per row
     * @return the inverse-normalized output matrix
     */
    public Matrix computeBP(BPModel bpModel,Matrix input) throws Exception {
        if (input.getMatrixColCount() != bpModel.getBpParameter().getInputLayerNeuronCount()) {
            throw new Exception("Input matrix has the wrong number of columns");
        }
        ActivationFunction activationFunction = bpModel.getBpParameter().getActivationFunction();
        Matrix weightIJ = bpModel.getWeightIJ();
        Matrix weightJP = bpModel.getWeightJP();
        Matrix b1 = bpModel.getB1();
        Matrix b2 = bpModel.getB2();
        // Normalize the input with the per-column min/max saved during training
        double[][] normalizedInput = new double[input.getMatrixRowCount()][input.getMatrixColCount()];
        for (int i = 0; i < input.getMatrixRowCount(); i++) {
            for (int j = 0; j < input.getMatrixColCount(); j++) {
                normalizedInput[i][j] = bpModel.getBpParameter().getNormalizationMin()
                        + (input.getValOfIdx(i,j) - bpModel.getInputMin().getValOfIdx(0,j))
                        / (bpModel.getInputMax().getValOfIdx(0,j) - bpModel.getInputMin().getValOfIdx(0,j))
                        * (bpModel.getBpParameter().getNormalizationMax() - bpModel.getBpParameter().getNormalizationMin());
            }
        }
        Matrix normalizedInputMatrix = new Matrix(normalizedInput);
        Matrix jIn = normalizedInputMatrix.multiple(weightIJ);
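        // Expand the thresholds to one row per sample
        Matrix b1Copy = b1.extend(2,jIn.getMatrixRowCount());
        // Add the thresholds
        jIn = jIn.plus(b1Copy);
        // Hidden-layer output
        Matrix jOut = computeValue(jIn,activationFunction);
        // Output-layer input
        Matrix pIn = jOut.multiple(weightJP);
        // Expand the thresholds to one row per sample
        Matrix b2Copy = b2.extend(2,pIn.getMatrixRowCount());
        // Add the thresholds
        pIn = pIn.plus(b2Copy);
        // Output-layer output
        Matrix pOut = computeValue(pIn,activationFunction);
        // Inverse normalization back to the original output range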
        return MatrixUtil.inverseNormalize(pOut, bpModel.getBpParameter().getNormalizationMax(), bpModel.getBpParameter().getNormalizationMin(), bpModel.getOutputMax(), bpModel.getOutputMin());
    }

    // Initialize the weights with uniform random values in [-1, 1)
    private Matrix initWeight(int x,int y){
        Random random=new Random();
        double[][] weight = new double[x][y];
        for (int i = 0; i < x; i++) {
            for (int j = 0; j < y; j++) {
                weight[i][j] = 2*random.nextDouble()-1;
            }
        }
        return new Matrix(weight);
    }
    // Initialize the thresholds with uniform random values in [-1, 1)
    private Matrix initThreshold(int x){
        Random random = new Random();
        double[][] result = new double[1][x];
        for (int i = 0; i < x; i++) {
            result[0][i] = 2*random.nextDouble()-1;
        }
        return new Matrix(result);
    }

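    /**
     * Applies the activation function element-wise
     * @param a input matrix
     * @param activationFunction activation function
     * @return matrix of activation values
     */
    private Matrix computeValue(Matrix a, ActivationFunction activationFunction) throws Exception {
        if (a.getMatrix() == null) {
            throw new Exception("Argument matrix is empty");
        }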
        double[][] result = new double[a.getMatrixRowCount()][a.getMatrixColCount()];
        for (int i = 0; i < a.getMatrixRowCount(); i++) {
            for (int j = 0; j < a.getMatrixColCount(); j++) {
                result[i][j] = activationFunction.computeValue(a.getValOfIdx(i,j));
            }
        }
        return new Matrix(result);
    }

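    /**
     * Applies the derivative of the activation function element-wise
     * @param a input matrix (net input of a layer)
     * @param activationFunction activation function
     * @return matrix of derivative values
     */
    private Matrix computeDerivative(Matrix a , ActivationFunction activationFunction) throws Exception {
        if (a.getMatrix() == null) {
            throw new Exception("Argument matrix is empty");
        }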
        double[][] result = new double[a.getMatrixRowCount()][a.getMatrixColCount()];
        for (int i = 0; i < a.getMatrixRowCount(); i++) {
            for (int j = 0; j < a.getMatrixColCount(); j++) {
                result[i][j] = activationFunction.computeDerivative(a.getValOfIdx(i,j));
            }
        }
        return new Matrix(result);
    }


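    /**
     * Computes the error E = 0.5 * sum of squared errors
     * @param e matrix of differences between expected and actual output
     * @return the error
     */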
    private double computeE(Matrix e){
        e = e.square();
        return 0.5*e.sumAll();
    }
}
BPNeuralNetworkFactory code
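A cheap way to verify backpropagation formulas like the ones in trainBP is a finite-difference gradient check. This matters most for the threshold corrections: because a threshold is added to the net input of every sample, its gradient is a column sum of the local gradients (a 1×J or 1×P row, the same shape as b1 or b2), not a J×J or P×P matrix product. The sketch below is independent of the project's Matrix class. It uses plain arrays, assumes the sigmoid as the activation function (taken to match the project's Sigmoid class), uses E = 0.5 · Σ(t − y)² as in computeE, and compares the backpropagated dE/db1 and dE/db2 with central differences on a random 3-4-2 network. All helper names (layer, thresholdGradients, numericGradient and so on) are hypothetical, made up for this illustration.

```java
import java.util.Random;
import java.util.function.DoubleSupplier;

public class BackpropGradientCheck {

    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    /** One fully connected layer with sigmoid activation: f(in * w + b). */
    static double[] layer(double[] in, double[][] w, double[] b) {
        double[] out = new double[b.length];
        for (int j = 0; j < b.length; j++) {
            double net = b[j]; // threshold added to the net input, as in trainBP
            for (int i = 0; i < in.length; i++) net += in[i] * w[i][j];
            out[j] = sigmoid(net);
        }
        return out;
    }

    /** E = 0.5 * sum of squared errors over all samples, as in computeE. */
    static double error(double[][] x, double[][] t, double[][] wIJ, double[] b1, double[][] wJP, double[] b2) {
        double e = 0;
        for (int n = 0; n < x.length; n++) {
            double[] y = layer(layer(x[n], wIJ, b1), wJP, b2);
            for (int p = 0; p < y.length; p++) e += 0.5 * (t[n][p] - y[p]) * (t[n][p] - y[p]);
        }
        return e;
    }

    /** Backpropagated dE/db1 and dE/db2: minus the column sums of delta_J and delta_P. */
    static double[][] thresholdGradients(double[][] x, double[][] t, double[][] wIJ, double[] b1,
                                         double[][] wJP, double[] b2) {
        double[] g1 = new double[b1.length];
        double[] g2 = new double[b2.length];
        for (int n = 0; n < x.length; n++) {
            double[] h = layer(x[n], wIJ, b1);
            double[] y = layer(h, wJP, b2);
            // delta_P = (t - y) * f'(pIn); for the sigmoid, f'(pIn) = y * (1 - y)
            double[] deltaP = new double[y.length];
            for (int p = 0; p < y.length; p++) {
                deltaP[p] = (t[n][p] - y[p]) * y[p] * (1 - y[p]);
                g2[p] -= deltaP[p];
            }
            // delta_J = (wJP * delta_P)_j * f'(jIn_j), with f'(jIn_j) = h_j * (1 - h_j)
            for (int j = 0; j < h.length; j++) {
                double back = 0;
                for (int p = 0; p < y.length; p++) back += wJP[j][p] * deltaP[p];
                g1[j] -= back * h[j] * (1 - h[j]);
            }
        }
        return new double[][] {g1, g2};
    }

    /** Central finite difference of f with respect to param[k]. */
    static double numericGradient(double[] param, int k, DoubleSupplier f) {
        double old = param[k];
        double h = 1e-6;
        param[k] = old + h;
        double plus = f.getAsDouble();
        param[k] = old - h;
        double minus = f.getAsDouble();
        param[k] = old; // restore the parameter
        return (plus - minus) / (2 * h);
    }

    /** Random matrix with values in [-1, 1), the same range initWeight uses. */
    static double[][] random(Random rnd, int rows, int cols) {
        double[][] m = new double[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) m[r][c] = 2 * rnd.nextDouble() - 1;
        }
        return m;
    }

    /** Largest |analytic - numeric| over all thresholds of a random 3-4-2 network on 5 samples. */
    static double maxThresholdGradientError(long seed) {
        Random rnd = new Random(seed);
        double[][] x = random(rnd, 5, 3);
        double[][] t = random(rnd, 5, 2);
        double[][] wIJ = random(rnd, 3, 4);
        double[] b1 = random(rnd, 1, 4)[0];
        double[][] wJP = random(rnd, 4, 2);
        double[] b2 = random(rnd, 1, 2)[0];
        double[][] g = thresholdGradients(x, t, wIJ, b1, wJP, b2);
        DoubleSupplier e = () -> error(x, t, wIJ, b1, wJP, b2);
        double worst = 0;
        for (int j = 0; j < b1.length; j++) worst = Math.max(worst, Math.abs(g[0][j] - numericGradient(b1, j, e)));
        for (int p = 0; p < b2.length; p++) worst = Math.max(worst, Math.abs(g[1][p] - numericGradient(b2, p, e)));
        return worst;
    }

    public static void main(String[] args) {
        System.out.println("max |analytic - numeric| = " + maxThresholdGradientError(42));
    }
}
```

Since trainBP moves every parameter by step times the negative gradient, the matching threshold corrections are step · Σ delta_J and step · Σ delta_P, both with a positive sign, just like the weight corrections.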

 

Usage

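The idea is as follows: create a BPNeuralNetworkFactory object and call its trainBP(BPParameter bpParameter, Matrix inputAndOutput) method with a BPParameter object and the training data. It returns a BPModel object, which you can serialize to local disk with the serialization methods of BPNeuralNetworkFactory, or keep in a cache. When you need a prediction, deserialize the BPModel from disk (or take it from the cache) and call the factory's computeBP(BPModel bpModel, Matrix input) method to get the computed values.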
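The factory's own serialization methods are not shown in this section. As an assumption, a minimal equivalent built on the standard java.io object streams could look like the sketch below; save and load are hypothetical names, and a plain double[][] stands in for a BPModel so that the example is self-contained. BPModel and Matrix implement Serializable, so a model could be passed to save directly, but ObjectOutputStream.writeObject only succeeds if every non-transient field is serializable too, which here includes the BPParameter and the ActivationFunction it holds.

```java
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;

public class ModelStoreSketch {

    /** Writes any Serializable object (for example a trained BPModel) to a file. */
    static void save(Serializable obj, Path file) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(Files.newOutputStream(file))) {
            out.writeObject(obj);
        }
    }

    /** Reads the object back and casts it to the expected type. */
    static <T> T load(Path file, Class<T> type) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(Files.newInputStream(file))) {
            return type.cast(in.readObject());
        }
    }

    public static void main(String[] args) throws Exception {
        // Stand-in for a BPModel: a serializable weight matrix
        double[][] weights = {{0.1, -0.2}, {0.3, 0.4}};
        Path file = Files.createTempFile("bpmodel", ".ser");
        save(weights, file);
        double[][] restored = load(file, double[][].class);
        System.out.println("round trip equal: " + Arrays.deepEquals(weights, restored));
        Files.deleteIfExists(file);
    }
}
```

With a real model the call would be along the lines of load(file, BPModel.class); if the class changes between saving and loading, declaring a serialVersionUID keeps old files readable.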

For detailed usage, see: https://github.com/ineedahouse/top-algorithm-set-doc/blob/master/doc/bpnn/BPNeuralNetwork.md

Source code on GitHub

https://github.com/ineedahouse/top-algorithm-set

If this helps you, please give it a Star. Thanks!

 

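Reference: Zhao Yixiang (赵逸翔). 基于BP神经网络的无约束优化方法研究及应用 [Research and Application of Unconstrained Optimization Methods Based on BP Neural Networks] [D]. Northeast Agricultural University, 2019.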

posted @ 2020-07-29 11:30 MrZhaoyx