数据库反向生成代码（MySQL + Spring Data JDBC）

之前说过要研究 Spring Data JDBC 的反向生成，于是花了点时间，把我之前给 MyBatis 用的代码生成类改造了一下，代码如下：

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.sql.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * Reverse-engineers Spring Data JDBC boilerplate from MySQL table metadata.
 * <p>
 * For each table returned by {@link #getTables()} it writes an entity class, a
 * {@code CrudRepository} interface, a service interface and a service
 * implementation into the {@code d:/Workspace} directories configured below.
 * Fill in the connection settings and run {@link #main(String[])}.
 * <p>
 * Created 2024/5/23 13:23.
 *
 * @author nxj
 */
public class DataJdbcCreatorUtil {

    /** Root package of the generated code. */
    private final String prefix = "com.example.jdbc";

    /**
     * Generates the files for all configured tables; failures are printed to stderr.
     */
    public static void main(String[] args) {
        try {
            new DataJdbcCreatorUtil().generate();
        } catch (ClassNotFoundException | SQLException | IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns the tables to generate code for. This list is hard-coded; edit it
     * before running.
     *
     * @return table names, in generation order
     * @throws SQLException not thrown by this implementation; kept so the list
     *                      can be loaded from the database instead
     */
    private List<String> getTables() throws SQLException {
        List<String> tables = new ArrayList<>();
        tables.add("sys_user");
        return tables;
    }

    // ------------------------------------------------------------ output directories

    private final String bean_path = "d:/Workspace/model/";

    private final String mapper_path = "d:/Workspace/mapper";

    private final String service_path = "d:/Workspace/service/";

    private final String service_impl_path = "d:/Workspace/service/impl";

    // ------------------------------------------------------------ packages of the generated types

    private final String bean_package = prefix + ".entity";

    private final String mapper_package = prefix + ".mapper";

    private final String service_package = prefix + ".service";

    private final String service_impl_package = prefix + ".service.impl";

    // ------------------------------------------------------------ database connection settings

    private final String url = "jdbc:mysql://你的MySQL连接";

    private final String driverName = "com.mysql.cj.jdbc.Driver";

    private final String user = "你的账户";

    private final String password = "你的密码";

    // ------------------------------------------------------------ per-table state, set in generate()/processTable()

    private String tableName = null;

    private String beanName = null;

    private String mapperName = null;

    private String serviceName = null;

    private String serviceImplName = null;

    /** Java field name of the column annotated with {@code @Id}; null when the table has no columns. */
    private String primaryKeyName = null;

    /** Java type of the {@code @Id} field, used as the repository's ID type argument. */
    private String primaryKeyType = null;

    /** Open only for the duration of {@link #generate()}. */
    private Connection conn = null;

    /** One row of MySQL's {@code SHOW FULL FIELDS} output. */
    private static final class Column {
        final String name;
        final String type;
        final String comment;
        final boolean primaryKey;

        Column(String name, String type, String comment, boolean primaryKey) {
            this.name = name;
            this.type = type;
            this.comment = comment;
            this.primaryKey = primaryKey;
        }
    }

    /** Loads the JDBC driver and opens the connection used by {@link #generate()}. */
    private void init() throws ClassNotFoundException, SQLException {
        Class.forName(driverName);
        conn = DriverManager.getConnection(url, user, password);
    }

    /**
     * Derives the generated type names from a table name, e.g. {@code sys_user}
     * becomes {@code SysUser}, {@code JdbcSysUserRepository}, {@code SysUserService}
     * and {@code SysUserServiceImpl}. Empty segments (doubled, leading or trailing
     * underscores) are ignored instead of crashing.
     */
    private void processTable(String table) {
        StringBuilder sb = new StringBuilder(table.length());
        // Locale.ROOT: locale-independent casing (e.g. avoids the Turkish dotless-i problem).
        for (String part : table.toLowerCase(Locale.ROOT).split("_")) {
            appendCapitalized(sb, part.trim());
        }
        beanName = sb.toString();
        mapperName = "Jdbc" + beanName + "Repository";
        serviceName = beanName + "Service";
        serviceImplName = serviceName + "Impl";
    }

    /** Appends {@code part} with its first character upper-cased; an empty part appends nothing. */
    private static void appendCapitalized(StringBuilder sb, String part) {
        if (!part.isEmpty()) {
            sb.append(part.substring(0, 1).toUpperCase(Locale.ROOT)).append(part.substring(1));
        }
    }

    /**
     * Lower-cases the first character of {@code str} when it is an ASCII capital
     * letter (A-Z); any other string, including the empty string, is returned unchanged.
     */
    private String shotFirst(String str) {
        if (str.isEmpty()) {
            return str;
        }
        char first = str.charAt(0);
        if (first >= 'A' && first <= 'Z') {
            return Character.toLowerCase(first) + str.substring(1);
        }
        return str;
    }

    /**
     * Maps a MySQL column type as reported by {@code SHOW FULL FIELDS}
     * (e.g. {@code varchar(64)}, {@code int(10) unsigned}) to a Java type name.
     * Matching is on the exact base type, so e.g. {@code point} is not mistaken
     * for {@code int}. Unmapped types fall back to {@code java.lang.Object} so the
     * generated source still compiles.
     */
    private String processType(String type) {
        if (type == null) {
            return "java.lang.Object";
        }
        String lower = type.trim().toLowerCase(Locale.ROOT);
        // Base type = text before the first '(' or space: "int(10) unsigned" -> "int".
        String base = lower.split("[ (]", 2)[0];
        boolean unsigned = lower.contains("unsigned");
        switch (base) {
            case "char":
            case "varchar":
            case "tinytext":
            case "text":
            case "mediumtext":
            case "longtext":
            case "json":
            case "enum":
            case "set":
                return "java.lang.String";
            case "tinyint":
            case "smallint":
            case "mediumint":
                return "java.lang.Integer";
            case "int":
            case "integer":
                // INT UNSIGNED goes up to 4294967295, which does not fit in an Integer.
                return unsigned ? "java.lang.Long" : "java.lang.Integer";
            case "bigint":
                return "java.lang.Long";
            case "bit":
                return "java.lang.Boolean";
            case "float":
                return "java.lang.Float";
            case "double":
            case "real":
                return "java.lang.Double";
            case "decimal":
            case "numeric":
                return "java.math.BigDecimal";
            case "date":
            case "datetime":
            case "timestamp":
            case "time":
                return "java.util.Date";
            case "binary":
            case "varbinary":
            case "tinyblob":
            case "blob":
            case "mediumblob":
            case "longblob":
                return "byte[]";
            default:
                return "java.lang.Object";
        }
    }

    /**
     * Converts a snake_case column name to a camelCase field name: the first
     * non-empty segment is kept as-is and later segments are capitalized
     * ({@code user_name} becomes {@code userName}). Empty segments are ignored.
     */
    private String processField(String field) {
        StringBuilder sb = new StringBuilder(field.length());
        for (String raw : field.split("_")) {
            String part = raw.trim();
            if (sb.length() == 0) {
                sb.append(part);
            } else {
                appendCapitalized(sb, part);
            }
        }
        return sb.toString();
    }

    /**
     * Makes database-supplied text safe to embed in a generated block comment.
     * Null becomes the empty string, and any closing comment delimiter is broken
     * up so it cannot end the comment early.
     */
    private static String sanitizeComment(String text) {
        return text == null ? "" : text.replace("*/", "* /");
    }

    /**
     * Creates {@code dir} if needed and opens {@code dir/fileName} for UTF-8
     * writing, truncating any existing file.
     *
     * @throws IOException if the directory cannot be created or the file opened
     */
    private BufferedWriter openWriter(String dir, String fileName) throws IOException {
        File folder = new File(dir);
        if (!folder.isDirectory() && !folder.mkdirs()) {
            throw new IOException("Cannot create directory " + folder.getAbsolutePath());
        }
        return new BufferedWriter(new OutputStreamWriter(
                new FileOutputStream(new File(folder, fileName)), StandardCharsets.UTF_8));
    }

    /**
     * Writes the Javadoc header placed above a generated type.
     *
     * @param description summary line, e.g. the table comment; null or empty leaves it blank
     * @param text        appended after the {@code @author} tag
     */
    private void buildClassComment(BufferedWriter bw, String description, String text)
            throws IOException {
        bw.newLine();
        bw.newLine();
        bw.write("/**");
        bw.newLine();
        bw.write(" * " + sanitizeComment(description));
        bw.newLine();
        bw.write(" * @author xx " + text);
        bw.newLine();
        bw.write(" * ");
        bw.newLine();
        bw.write(" **/");
    }

    /**
     * Writes the Lombok {@code @Data} entity for the current table. The field
     * named {@link #primaryKeyName} is annotated with {@code @Id}.
     *
     * @param columns      the table's columns, in table order
     * @param tableComment the table comment (may be null), used as the class summary
     * @throws IOException if the file cannot be written
     */
    private void buildEntityBean(List<Column> columns, String tableComment) throws IOException {
        try (BufferedWriter bw = openWriter(bean_path, beanName + ".java")) {
            bw.write("package " + bean_package + ";");
            bw.newLine();
            bw.newLine();
            bw.write("import lombok.Data;");
            bw.newLine();
            bw.write("import org.springframework.data.annotation.Id;");
            bw.newLine();
            bw.write("import org.springframework.data.relational.core.mapping.Table;");
            bw.newLine();
            bw.write("import java.io.Serial;");
            bw.newLine();
            bw.write("import java.io.Serializable;");
            buildClassComment(bw, tableComment, beanName + ".java");
            bw.newLine();
            bw.write("@Data");
            bw.newLine();
            bw.write("@Table(\"" + tableName + "\")");
            bw.newLine();
            bw.write("public class " + beanName + " implements Serializable {");
            bw.newLine();
            bw.write("\t@Serial");
            bw.newLine();
            bw.write("\tprivate static final long serialVersionUID = 1L;");
            bw.newLine();
            bw.newLine();
            for (Column column : columns) {
                String fieldName = processField(column.name);
                String comment = sanitizeComment(column.comment);
                if (!comment.trim().isEmpty()) {
                    // Spaces keep a comment starting with '/' from closing the comment immediately.
                    bw.write("\t/** " + comment + " */");
                    bw.newLine();
                }
                if (fieldName.equals(primaryKeyName)) {
                    bw.write("\t@Id");
                    bw.newLine();
                }
                bw.write("\tprivate " + processType(column.type) + " " + fieldName + ";");
                bw.newLine();
                bw.newLine();
            }
            bw.write("}");
            bw.newLine();
        }
    }

    /**
     * Writes the (empty) service interface for the current table.
     *
     * @throws IOException if the file cannot be written
     */
    private void buildService() throws IOException {
        try (BufferedWriter bw = openWriter(service_path, serviceName + ".java")) {
            bw.write("package " + service_package + ";");
            bw.newLine();
            bw.newLine();
            bw.write("import " + bean_package + "." + beanName + ";");
            bw.newLine();
            buildClassComment(bw, "", serviceName);
            bw.newLine();
            bw.write("public interface " + serviceName + " {");
            bw.newLine();
            bw.newLine();
            bw.newLine();
            bw.newLine();
            bw.write("}");
        }
    }

    /**
     * Writes the {@code @Service} implementation for the current table, with the
     * repository injected via {@code @Autowired}.
     *
     * @throws IOException if the file cannot be written
     */
    private void buildServiceImpl() throws IOException {
        try (BufferedWriter bw = openWriter(service_impl_path, serviceImplName + ".java")) {
            bw.write("package " + service_impl_package + ";");
            bw.newLine();
            bw.newLine();
            bw.write("import org.springframework.beans.factory.annotation.Autowired;");
            bw.newLine();
            bw.write("import " + mapper_package + "." + mapperName + ";");
            bw.newLine();
            bw.write("import org.springframework.stereotype.Service;");
            bw.newLine();
            bw.write("import " + service_package + "." + serviceName + ";");
            bw.newLine();
            bw.write("import " + bean_package + "." + beanName + ";");
            bw.newLine();
            buildClassComment(bw, "", serviceImplName);
            bw.newLine();
            bw.write("@Service(\"" + shotFirst(serviceName) + "\")");
            bw.newLine();
            bw.write("public class " + serviceImplName + " implements " + serviceName + " {");
            bw.newLine();
            bw.newLine();
            bw.write("\t@Autowired");
            bw.newLine();
            bw.write("\tprivate " + mapperName + " " + shotFirst(mapperName) + ";");
            bw.newLine();
            bw.newLine();
            bw.write("}");
        }
    }

    /**
     * Writes the {@code CrudRepository} interface for the current table, using
     * {@link #primaryKeyType} as the ID type argument.
     *
     * @throws IOException if the file cannot be written
     */
    private void buildMapper() throws IOException {
        try (BufferedWriter bw = openWriter(mapper_path, mapperName + ".java")) {
            bw.write("package " + mapper_package + ";");
            bw.newLine();
            bw.newLine();
            bw.write("import " + bean_package + "." + beanName + ";");
            bw.newLine();
            bw.write("import org.springframework.data.repository.CrudRepository;");
            bw.newLine();
            bw.write("import org.springframework.stereotype.Repository;");
            bw.newLine();
            buildClassComment(bw, "", mapperName + "数据库操作接口类");
            bw.newLine();
            bw.write("@Repository");
            bw.newLine();
            bw.write("public interface " + mapperName + " extends CrudRepository<" + beanName + ","
                    + primaryKeyType + "> {");
            bw.newLine();
            bw.newLine();
            bw.write("}");
        }
    }

    /**
     * Reads every table's comment in the current schema via {@code show table status}.
     *
     * @return table name mapped to its comment
     */
    private Map<String, String> getTableComment() throws SQLException {
        Map<String, String> comments = new HashMap<>();
        try (Statement st = conn.createStatement();
             ResultSet rs = st.executeQuery("show table status")) {
            while (rs.next()) {
                comments.put(rs.getString("NAME"), rs.getString("COMMENT"));
            }
        }
        return comments;
    }

    /**
     * Reads the columns of {@code table}. The name is backtick-quoted, with embedded
     * backticks doubled, so reserved words such as {@code order} also work.
     */
    private List<Column> readColumns(String table) throws SQLException {
        List<Column> columns = new ArrayList<>();
        String sql = "show full fields from `" + table.replace("`", "``") + "`";
        try (Statement st = conn.createStatement(); ResultSet rs = st.executeQuery(sql)) {
            while (rs.next()) {
                columns.add(new Column(
                        rs.getString("FIELD"),
                        rs.getString("TYPE"),
                        rs.getString("COMMENT"),
                        "PRI".equalsIgnoreCase(rs.getString("KEY"))));
            }
        }
        return columns;
    }

    /**
     * Picks the column to annotate with {@code @Id}: the first column whose key
     * is {@code PRI}. If there is none, it falls back to the first column, which
     * is what this generator always assumed. Returns null when there are no columns.
     * For a composite key, only the first key column is used (the generated
     * repository has a single ID type), and a notice is printed.
     */
    private Column findIdColumn(List<Column> columns) {
        Column id = null;
        int keyCount = 0;
        for (Column column : columns) {
            if (column.primaryKey) {
                if (id == null) {
                    id = column;
                }
                keyCount++;
            }
        }
        if (keyCount > 1) {
            System.out.println("==composite primary key, using " + id.name + " as @Id====>>>" + tableName);
        }
        if (id != null) {
            return id;
        }
        return columns.isEmpty() ? null : columns.get(0);
    }

    /**
     * Connects to the database and generates the entity, repository, service and
     * service implementation for every table in {@link #getTables()}. Tables named
     * {@code test}, {@code user} or {@code support_gen}, and tables prefixed with
     * {@code t_}, are skipped. The connection is closed even if generation fails.
     *
     * @throws ClassNotFoundException if the JDBC driver is not on the classpath
     * @throws SQLException           on any database error
     * @throws IOException            if a generated file cannot be written
     */
    public void generate() throws ClassNotFoundException, SQLException,
            IOException {
        init();
        try {
            Map<String, String> tableComments = getTableComment();
            for (String table : getTables()) {
                System.out.println("======>>>" + table);
                if (table.equalsIgnoreCase("test")) {
                    continue;
                }
                // Check the skip list before querying columns we would not use.
                if (table.startsWith("t_") || table.equalsIgnoreCase("user")
                        || table.equalsIgnoreCase("support_gen")) {
                    System.out.println("==skip table====>>>" + table);
                    continue;
                }
                tableName = table;
                processTable(table);
                List<Column> columns = readColumns(table);
                Column idColumn = findIdColumn(columns);
                primaryKeyName = idColumn == null ? null : processField(idColumn.name);
                primaryKeyType = idColumn == null ? "java.lang.Long" : processType(idColumn.type);
                buildEntityBean(columns, tableComments.get(table));
                buildMapper();
                buildService();
                buildServiceImpl();
            }
        } finally {
            conn.close();
            conn = null;
        }
    }
}

 注意：转载请注明出处 https://www.cnblogs.com/nxjblog/p/18217361

posted @ 2024-05-28 10:31  轻寒  阅读(17)  评论(0)  编辑  收藏  举报