构建支持多种数据库类型的代码自动生成工具

背景:

一般的业务代码中写来写去,无外乎是先建好model,然后针对这个model做些CRUD的操作。(主要针对单表的业务操作)针对于数据库dao、mapper等的代码自动生成已经有了mybatisGenerator这种工具,但是针对于controller、service这些我们现在的接口api一般遵循的是restful风格,因此这些也是有规则可循的。举例有个goodsInfo 的model,针对于他的操作,肯定有 单个查询、list查询、修改、删除等。而这些代码没必要复制粘贴来一遍,完全可以由工具自动生成,若有特殊业务场景重写即可。本工具就算解决这类问题的。

效果截图

运行生成示例结果:
blob.jpg
blob.jpg

表选择界面:
blob.jpg

思路:

代码自动生成说起来很神秘,其实无外乎两个方面:

  1. 从数据库拿到需要自动生成的代码对应表。
  2. 从表结构、字段名生成对应的mapper、model、及controller、service等

如何拿到需要自动生成的代码对应表

sqlservr、mysql、oracle等这些主流数据库中都存在系统表结构的表,存储的是所有用户自己建立表的名称、字段等,所以直接查询这些系统表即可罗列出所有业务表。然后做个可视化界面供用户选择即可。(这里做一下更新,我实际项目中没有用sql查询的方式,因为不同数据库对于系统表的存储方式各不相同,查询语句写的太蛋疼了,实际采用的是 conn.getMetaData() 的方式,采用元数据来拿到指定数据库中各种表结构信息)

如何自动生成代码

有了表结构、字段名等如何自动生成代码呢,这个时候就需要模板引擎了。简单来讲可以理解为把固定的地方写死,变化的地方按照规则替换。
可以用我们小时候写作文的例子来说明。我们(作文厉害的请自动忽略 “们” 😃)小时候写作文,一般是3段式,开头、结尾、和中间流水账。 开头一般是描写环境心情、中间讲述具体故事,结尾总结赞美。

今天天气不错,风和日丽的,我们早早就来到了学些,大家都很开心。(开头)
	小明,突然在地上捡到了一个钱包……(一顿思想斗争,最后交给了警察叔叔)
最后,这个故事告诉了我们……(结尾)

从上面的示例范文中是不是很熟悉,。基本上都是这个结构,中间基本可以随意替换,最后都能凑成一篇基本合格的小学作文。而我们现在要做的就是把一些表名称、字段名称当做需要填充的内容填充到指定的代码段中去。

具体实现

获取数据库表、字段等信息

好了,上面讲了一大堆废话(背景和思路,个人觉得还是有必要的),下面到具体实现中来。
获取数据库表结构(表、字段)信息关键代码如下


@Service
public class DbServiceImpl implements IDbService {

    private Logger logger = LoggerFactory.getLogger(this.getClass());

    @Value("${spring.datasource.driverClassName}")
    private String driverClassName;
    @Value("${spring.datasource.url}")
    private String url;
    @Value("${spring.datasource.username}")
    private String user;
    @Value("${spring.datasource.password}")
    private String pwd;

    @Override
    public List<TableEntity> getTables(String tableName) {
        List<TableEntity> tables = new ArrayList<>();
        try {
            Class.forName(driverClassName);// 动态加载mysql驱动
            Connection connection = DriverManager.getConnection(url, user, pwd);
            DatabaseMetaData metaData = connection.getMetaData();
            ResultSet resultSet = metaData.getTables(null, null, "%", new String[]{"TABLE"});
//            metaData.getTables("yjc", "", "%", new String[]{"TABLE"})
            while (resultSet.next()) {
                if (resultSet.getString("TABLE_NAME").contains(tableName)) {
                    TableEntity tmpTable = new TableEntity();
                    tmpTable.setTbName(resultSet.getString("TABLE_NAME"));
                    tmpTable.setComments(resultSet.getString("REMARKS"));
                    tmpTable.setCatalog(resultSet.getString("TABLE_CAT"));
                    tmpTable.setSchema(resultSet.getString("TABLE_SCHEM"));
                    tables.add(tmpTable);
                }
            }
        } catch (Exception e) {
            logger.error("获取数据库表列表失败", e);
        }
        return tables;
    }

    @Override
    public List<ColumnEntity> getColumns(String tableName) {
        List<ColumnEntity> columnEntityList = new ArrayList<>();
        try {
            Class.forName(driverClassName);// 动态加载mysql驱动
            Connection connection = DriverManager.getConnection(url, user, pwd);
            DatabaseMetaData metaData = connection.getMetaData();
            ResultSet resultSet = metaData.getColumns(null, null, tableName, "%");
            while (resultSet.next()) {
                ColumnEntity tmpColumnEntity = new ColumnEntity();
                tmpColumnEntity.setColumnName(resultSet.getString("COLUMN_NAME"));
                tmpColumnEntity.setBufferLength(resultSet.getInt("BUFFER_LENGTH"));
                tmpColumnEntity.setColumnSize(resultSet.getInt("COLUMN_SIZE"));
                tmpColumnEntity.setComments(resultSet.getString("REMARKS"));
                tmpColumnEntity.setDecimalDigits(resultSet.getInt("DECIMAL_DIGITS"));
                tmpColumnEntity.setDataType(resultSet.getInt("DATA_TYPE"));
                tmpColumnEntity.setTypeName(resultSet.getString("TYPE_NAME"));
                tmpColumnEntity.setIsNullAble(resultSet.getString("IS_NULLABLE"));
                tmpColumnEntity.setIsAutoIncrement(resultSet.getString("IS_AUTOINCREMENT"));
                columnEntityList.add(tmpColumnEntity);
            }
        } catch (Exception e) {
            logger.error("查询表的列发生异常", e);
        }
        return columnEntityList;
    }

    @Override
    public TableEntity getTableEntity(String tableName) {
        TableEntity tableEntity = new TableEntity();
        try {
            Class.forName(driverClassName);// 动态加载mysql驱动
            Connection connection = DriverManager.getConnection(url, user, pwd);
            DatabaseMetaData metaData = connection.getMetaData();
            ResultSet resultSet = metaData.getTables(null, null, tableName, new String[]{"TABLE"});
            while (resultSet.next()) {
                if(tableName.equals(resultSet.getString("TABLE_NAME")))
                {
                    tableEntity.setTbName(resultSet.getString("TABLE_NAME"));
                    tableEntity.setComments(resultSet.getString("REMARKS"));
                    tableEntity.setCatalog(resultSet.getString("TABLE_CAT"));
                    tableEntity.setSchema(resultSet.getString("TABLE_SCHEM"));
                    tableEntity.setPk(this.getPrimaryKeyColumnName(metaData,tableEntity.getCatalog(),tableEntity.getSchema(),tableEntity.getTbName()));
                }
            }
        } catch (Exception e) {
            logger.error("获取表对象失败", e);
        }
        return tableEntity;
    }

    private String getPrimaryKeyColumnName(DatabaseMetaData metaData,String catalog,String schema,String tableName)
    {
        String primaryKeyColumnName="";
        try {
            ResultSet resultSet = metaData.getPrimaryKeys(catalog, schema, tableName);
            while (resultSet.next())
            {
                primaryKeyColumnName=   resultSet.getString("COLUMN_NAME");
            }
        } catch (SQLException e) {
            logger.error("获取主键发生异常",e);
        }
        return primaryKeyColumnName;
    }
}

另外在使用 DatabaseMetaData获取表、列信息的时候,如

  DatabaseMetaData metaData = connection.getMetaData();
  ResultSet resultSet = metaData.getColumns(null, null, tableName, "%");
  DatabaseMetaData metaData = connection.getMetaData();
  ResultSet resultSet = metaData.getTables(null, null, "%", new String[]{"TABLE"});

获取表格信息、获取列信息都是返回的 ResultSet,这个ResultSet 有点蛋疼,需要按照字段来查询,或者指定索引顺序来获取想要的结果,对照关系如下面截图

blob.jpg

blob.jpg

使用模板生成代码

使用的是velocity引擎(当然也可以使用freemarker等,这个不重要)
模板代码示例如下:

package ${package}.${moduleName}.entity;

import com.baomidou.mybatisplus.annotations.TableId;
import com.baomidou.mybatisplus.annotations.TableName;

#if(${hasBigDecimal})
import java.math.BigDecimal;
#end
import java.io.Serializable;
import java.util.Date;

/**
 * ${comments}
 * 
 * @author ${author}
 * @email ${email}
 * @date ${datetime}
 */
@TableName("${tableName}")
public class ${className}Entity implements Serializable {
	private static final long serialVersionUID = 1L;

#foreach ($column in $columns)
	/**
	 * $column.comments
	 */
	#if($column.columnName == $pk.columnName)
@TableId
	#end
private $column.attrType $column.attrname;
#end

#foreach ($column in $columns)
	/**
	 * 设置:${column.comments}
	 */
	public void set${column.attrName}($column.attrType $column.attrname) {
		this.$column.attrname = $column.attrname;
	}
	/**
	 * 获取:${column.comments}
	 */
	public $column.attrType get${column.attrName}() {
		return $column.attrname;
	}
#end
}

package ${package}.${moduleName}.controller;

import java.util.Arrays;
import java.util.Map;

import org.apache.shiro.authz.annotation.RequiresPermissions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

import ${package}.${moduleName}.entity.${className}Entity;
import ${package}.${moduleName}.service.${className}Service;
import ${mainPath}.common.utils.PageUtils;
import ${mainPath}.common.utils.R;



/**
 * ${comments}
 *
 * @author ${author}
 * @email ${email}
 * @date ${datetime}
 */
@RestController
@RequestMapping("${moduleName}/${pathName}")
public class ${className}Controller {
    @Autowired
    private ${className}Service ${classname}Service;

    /**
     * 列表
     */
    @GetMapping("/list")
    public  ResponseEntity<BaseResponse<Page<${className}>>>  list(@PageableDefault(value = 15, sort = { "${pk}" }, direction = Sort.Direction.DESC) Pageable pageable)
    {
        BaseResponse<Page<${className}> > baseResponse=new BaseResponse<>();
        Page<${className}> all = ${classname}Repository.findAll(pageable)};
        if(all!=null && !all.isEmpty())
    {
        return  new ResponseEntity<BaseResponse<Page<${className}>>>(BaseResponseFactory.success(all),HttpStatus.OK);
    }
        else
    {
        return  new ResponseEntity<>(HttpStatus.BAD_REQUEST);
    }
}


    /**
    * 单个查询
    */
    @GetMapping("/{${pk}}")
    @ApiOperation(value = "/{${pk}}", httpMethod = "GET", notes = "查询单个${className}信息}")
    public  ResponseEntity<BaseResponse<${className}>> info(@PathVariable long ${pk}) {
        Optional<${className}> optional = ${classname}Repository.findById(${pk});
        if(optional.isPresent())
        {
            return   new ResponseEntity<>(BaseResponseFactory.success(optional.get()), HttpStatus.OK);
        }else
        {
            return new ResponseEntity(HttpStatus.NOT_FOUND);
        }
    }


    /**
    * 保存
    */
    @PostMapping("/add")
    public ResponseEntity<BaseResponse<${className}>> save(@Validated @RequestBody ${className} goodsInfo, BindingResult bindingResult)
    {
        ResponseEntity<BaseResponse<${className}>> responseEntity;
        BaseResponse  baseResponse=new BaseResponse<>();
        if(bindingResult.hasErrors())
        {
            StringBuilder sb=new StringBuilder();
            for (FieldError fieldError : bindingResult.getFieldErrors()) {
                sb.append(fieldError.getDefaultMessage());
                sb.append(" ");
            }
            baseResponse.setCode(400);
            baseResponse.setMessage(sb.toString());
            responseEntity=new ResponseEntity<>(baseResponse,HttpStatus.BAD_REQUEST);
        }
        else
        {
            ${className} save = ${classname}Repository.save(${classname});
            baseResponse.setCode(200);
            baseResponse.setMessage("保存成功");
            baseResponse.setData(save);
            responseEntity=new ResponseEntity<>(baseResponse, HttpStatus.OK);
        }
        return  responseEntity;
    }


}

字段对应转换

如何将数据库的字段类型对应到java代码上,比如数据库中的varchar,需要对应到java的String,本例是参考了一个自动生成工具的方式,使用了对应配置表,内容如下。

#代码生成器,配置信息

mainPath=com.
#包名
package=redheart
moduleName=erp
#作者
author=pf
#Email
email=103868365@qq.com
#表前缀(类名不会包含表前缀)
tablePrefix=yjc_

#类型转换,配置信息
TINYINT=Integer
SMALLINT=Integer
MEDIUMINT=Integer
INT=Integer
INTEGER=Integer
BIGINT=Long
FLOAT=Float
DOUBLE=Double
DECIMAL=BigDecimal
BIT=Boolea
CHAR=String
VARCHAR=String
TINYTEXT=String
TEXT=String
MEDIUMTEXT=String
LONGTEXT=String
DATE=Date
DATETIME=Date
TIMESTAMP=Date
posted on 2019-06-20 17:46  falcon_fei  阅读(1483)  评论(0编辑  收藏  举报