手写MyBatis(非Mapper XML版本)

MyBatis部分的自定义注解代码(Param 与 Select 均位于 com.hz.mybatis.annotation 包)

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Names a mapper-method parameter so that its argument can be referenced from
 * {@code #{name}} placeholders in the method's SQL.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.PARAMETER})
public @interface Param {

    /** Placeholder name the annotated argument is bound to. */
    String value();
}
package com.hz.mybatis.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a mapper method whose SQL, given in {@link #value()}, is executed as a query.
 * Placeholders are written as {@code #{name}}, where {@code name} matches a {@code @Param} value.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface Select {

    /** SQL text to execute; empty by default. */
    String value() default "";
}

MyBatis部分的类型处理器(TypeHandler)代码

package com.hz.mybatis.handler;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;

/**
 * Binds and reads {@link Integer} values, mapping Java {@code null} to and from SQL NULL.
 */
public class IntegerTypeHandler implements TypeHandler<Integer> {

    /**
     * Binds {@code parameter} to placeholder {@code i}; a {@code null} parameter is bound as SQL NULL.
     */
    @Override
    public void setParameter(PreparedStatement ps, int i, Integer parameter) throws SQLException {
        if (parameter == null) {
            // ps.setInt(i, parameter) would throw NullPointerException while unboxing.
            ps.setNull(i, java.sql.Types.INTEGER);
        } else {
            ps.setInt(i, parameter);
        }
    }

    /**
     * Reads {@code columnName} from the current row.
     *
     * @return the column value, or {@code null} when the column is SQL NULL
     */
    @Override
    public Integer getResult(ResultSet rs, String columnName) throws SQLException {
        int value = rs.getInt(columnName);
        // getInt reports SQL NULL as 0; wasNull distinguishes a stored 0 from NULL.
        return rs.wasNull() ? null : value;
    }
}
package com.hz.mybatis.handler;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;

/**
 * Binds and reads {@link String} values through {@code setString} / {@code getString}.
 */
public class StringTypeHandler implements TypeHandler<String> {

    /** Binds {@code parameter} to placeholder {@code i} of {@code ps}. */
    @Override
    public void setParameter(PreparedStatement ps, int i, String parameter) throws SQLException {
        ps.setString(i, parameter);
    }

    /** Returns the text of {@code columnName} in the current row of {@code rs}. */
    @Override
    public String getResult(ResultSet rs, String columnName) throws SQLException {
        String text = rs.getString(columnName);
        return text;
    }
}
package com.hz.mybatis.handler;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;

/**
 * Converts values of Java type {@code T} to and from JDBC.
 *
 * @param <T> the Java type this handler binds and reads
 */
public interface TypeHandler<T> {

    /**
     * Binds {@code parameter} to placeholder {@code i} of {@code ps}.
     *
     * @param ps        statement being prepared
     * @param i         JDBC placeholder index (1-based)
     * @param parameter value to bind
     * @throws SQLException if the driver rejects the value
     */
    void setParameter(PreparedStatement ps, int i, T parameter) throws SQLException;

    /**
     * Reads column {@code columnName} of the current row of {@code rs}.
     *
     * @param rs         result set positioned on a row
     * @param columnName column label to read
     * @return the converted value
     * @throws SQLException if the column cannot be read
     */
    T getResult(ResultSet rs, String columnName) throws SQLException;
}

MyBatis部分的SQL解析器代码(将 #{...} 占位符转换为 ?)

/**
 *    Copyright 2009-2017 the original author or authors.
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */
package com.hz.mybatis.parser;

/**
 * Replaces each {@code openToken ... closeToken} span in a string with the value a
 * {@link TokenHandler} returns for the text between the tokens (for example the pair used by
 * {@code #{name}} placeholders).
 * <p>
 * A backslash immediately before an open token escapes it: the backslash is dropped and the
 * open token is copied as literal text. Inside an expression, a backslash immediately before a
 * close token likewise keeps that close token as part of the expression text. An open token
 * with no matching close token is copied through unchanged, together with the rest of the input.
 *
 * @author Clinton Begin
 */
public class GenericTokenParser {

  // Delimiter that starts a token, e.g. "#{".
  private final String openToken;
  // Delimiter that ends a token, e.g. "}".
  private final String closeToken;
  // Receives each token's inner text and supplies its replacement.
  private final TokenHandler handler;

  public GenericTokenParser(String openToken, String closeToken, TokenHandler handler) {
    this.openToken = openToken;
    this.closeToken = closeToken;
    this.handler = handler;
  }

  /**
   * Parses {@code text}, substituting every complete token with the handler's result.
   *
   * @param text input string; {@code null} or empty yields {@code ""}
   * @return the input with tokens replaced; returned as-is when it contains no open token
   */
  public String parse(String text) {
    if (text == null || text.isEmpty()) {
      return "";
    }
    // search open token
    int start = text.indexOf(openToken, 0);
    if (start == -1) {
      return text;
    }
    char[] src = text.toCharArray();
    // Index of the first character of src not yet copied into builder or expression.
    int offset = 0;
    final StringBuilder builder = new StringBuilder();
    // Buffer for the current token's inner text; allocated lazily and reused across tokens.
    StringBuilder expression = null;
    while (start > -1) {
      if (start > 0 && src[start - 1] == '\\') {
        // this open token is escaped. remove the backslash and continue.
        builder.append(src, offset, start - offset - 1).append(openToken);
        offset = start + openToken.length();
      } else {
        // found open token. let's search close token.
        if (expression == null) {
          expression = new StringBuilder();
        } else {
          expression.setLength(0);
        }
        builder.append(src, offset, start - offset);
        offset = start + openToken.length();
        int end = text.indexOf(closeToken, offset);
        while (end > -1) {
          // The end > offset guard only inspects a character inside the current expression segment.
          if (end > offset && src[end - 1] == '\\') {
            // this close token is escaped. remove the backslash and continue.
            expression.append(src, offset, end - offset - 1).append(closeToken);
            offset = end + closeToken.length();
            end = text.indexOf(closeToken, offset);
          } else {
            expression.append(src, offset, end - offset);
            offset = end + closeToken.length();
            break;
          }
        }
        if (end == -1) {
          // close token was not found.
          // Copy everything from the open token to the end of the input verbatim.
          builder.append(src, start, src.length - start);
          offset = src.length;
        } else {
          builder.append(handler.handleToken(expression.toString()));
          // Same value already assigned just before the break above.
          offset = end + closeToken.length();
        }
      }
      start = text.indexOf(openToken, offset);
    }
    // Append whatever follows the last processed token.
    if (offset < src.length) {
      builder.append(src, offset, src.length - offset);
    }
    return builder.toString();
  }
}
package com.hz.mybatis.parser;

/**
 * One {@code #{...}} placeholder found in a SQL string, identified by the name written inside it.
 */
public class ParameterMapping {

    /** Name written between the placeholder delimiters. */
    private String property;

    /**
     * @param property placeholder name
     */
    public ParameterMapping(String property) {
        this.property = property;
    }

    /** @return the placeholder name */
    public String getProperty() {
        return this.property;
    }

    /** @param property new placeholder name */
    public void setProperty(String property) {
        this.property = property;
    }
}
package com.hz.mybatis.parser;

import java.util.ArrayList;
import java.util.List;

/**
 * Token handler that turns every placeholder into a JDBC {@code ?} and records the placeholder
 * names in the order they were encountered.
 */
public class ParameterMappingTokenHandler implements TokenHandler {

    // Placeholder names in encounter order; index k corresponds to JDBC parameter k + 1.
    private final List<ParameterMapping> mappings = new ArrayList<>();

    /**
     * Records {@code content} as the next placeholder.
     *
     * @return the JDBC positional marker {@code "?"}
     */
    @Override
    public String handleToken(String content) {
        ParameterMapping mapping = new ParameterMapping(content);
        mappings.add(mapping);
        return "?";
    }

    /** @return the recorded placeholders (the live internal list) */
    public List<ParameterMapping> getParameterMappings() {
        return mappings;
    }
}
package com.hz.mybatis.parser;

/**
 * Callback that receives the text found between a token pair and returns its replacement.
 */
@FunctionalInterface
public interface TokenHandler {

    /**
     * @param content text between the open and close tokens
     * @return text substituted for the whole token
     */
    String handleToken(String content);
}

MyBatis部分:通过动态代理调用JDBC的代码(核心代码)

package com.hz.mybatis.mapper;

import com.hz.mybatis.annotation.Param;
import com.hz.mybatis.annotation.Select;
import com.hz.mybatis.handler.IntegerTypeHandler;
import com.hz.mybatis.handler.StringTypeHandler;
import com.hz.mybatis.handler.TypeHandler;
import com.hz.mybatis.parser.GenericTokenParser;
import com.hz.mybatis.parser.ParameterMapping;
import com.hz.mybatis.parser.ParameterMappingTokenHandler;
import java.lang.reflect.*;
import java.sql.*;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Builds JDK dynamic proxies for mapper interfaces whose methods carry {@link Select}.
 * Each call opens a JDBC connection, runs the annotated SQL with {@link Param}-named arguments
 * bound to its {@code #{name}} placeholders, and maps the rows onto the method's result type
 * through that type's public setters.
 */
public class MapperProxyFactory {

    // Connection settings for the demo database.
    // NOTE(review): hard-coded credentials; move to external configuration before real use.
    private static final String JDBC_URL = "jdbc:mysql://localhost:3306/hs?serverTimezone=UTC&useUnicode=true&characterEncoding=utf-8&AllowPublicKeyRetrieval=True";
    private static final String JDBC_USERNAME = "root";
    private static final String JDBC_PASSWORD = "root";

    // Placeholder delimiters accepted in @Select SQL, e.g. #{userName}.
    private static final String OPEN_TOKEN = "#{";
    private static final String CLOSE_TOKEN = "}";

    // Java type -> handler used to bind parameters of that type and read columns into it.
    private static final Map<Class<?>, TypeHandler<?>> TYPE_HANDLERS = new ConcurrentHashMap<>();

    static {
        try {
            // Best effort: JDBC 4+ drivers also self-register through the service-provider
            // mechanism, so a missing legacy class name is reported but not fatal.
            Class.forName("com.mysql.jdbc.Driver");
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
        TYPE_HANDLERS.put(String.class, new StringTypeHandler());
        TYPE_HANDLERS.put(Integer.class, new IntegerTypeHandler());
    }

    /**
     * Creates a proxy implementing {@code mapper}.
     *
     * @param mapper mapper interface whose methods are annotated with {@link Select}
     * @return proxy whose {@code @Select} methods execute their SQL over JDBC
     */
    public static <T> T getMapper(Class<T> mapper) {
        InvocationHandler handler = (proxy, method, args) -> {
            if (method.getDeclaringClass() == Object.class) {
                // toString/equals/hashCode are answered locally instead of being treated as SQL.
                return invokeObjectMethod(mapper, proxy, method, args);
            }
            // Only @Select is supported so far.
            Select select = method.getAnnotation(Select.class);
            if (select == null) {
                throw new RuntimeException("未被注解修饰。。");
            }
            return executeSelect(method, select.value(), args);
        };
        // Use the mapper's own class loader: under Spring Boot's executable-jar launcher the
        // system class loader cannot see application classes, so newProxyInstance would fail.
        Object proxyInstance = Proxy.newProxyInstance(mapper.getClassLoader(), new Class<?>[]{mapper}, handler);
        return mapper.cast(proxyInstance);
    }

    /** Implements the {@link Object} methods that a proxy forwards to its handler. */
    private static Object invokeObjectMethod(Class<?> mapper, Object proxy, Method method, Object[] args) {
        switch (method.getName()) {
            case "equals":
                return proxy == args[0];
            case "hashCode":
                return System.identityHashCode(proxy);
            case "toString":
                return mapper.getName() + "$Proxy@" + Integer.toHexString(System.identityHashCode(proxy));
            default:
                throw new UnsupportedOperationException(method.toString());
        }
    }

    /**
     * Runs one {@code @Select} statement; all JDBC resources are closed before returning.
     *
     * @return for a {@code List} return type, the mapped rows (empty when there are none);
     *         otherwise the first mapped row, or {@code null} when there are none
     */
    private static Object executeSelect(Method method, String sqlStr, Object[] args)
            throws SQLException, ReflectiveOperationException {
        if (sqlStr == null || sqlStr.equals("")) {
            throw new RuntimeException("注释中没有写入SQL");
        }
        // Rewrite #{name} placeholders to "?" while recording the names in order,
        // e.g. "... where name = #{userName} and age = #{age}" -> "... where name = ? and age = ?".
        ParameterMappingTokenHandler tokenHandler = new ParameterMappingTokenHandler();
        String parsedSql = new GenericTokenParser(OPEN_TOKEN, CLOSE_TOKEN, tokenHandler).parse(sqlStr);

        // Validate arguments and resolve the result type before acquiring a connection.
        Map<String, Object> paramsMap = collectParams(method, args);
        Class<?> resultType = resolveResultType(method);

        try (Connection conn = DriverManager.getConnection(JDBC_URL, JDBC_USERNAME, JDBC_PASSWORD);
             PreparedStatement stmt = conn.prepareStatement(parsedSql)) {
            bindParameters(stmt, tokenHandler.getParameterMappings(), paramsMap);
            try (ResultSet rs = stmt.executeQuery()) {
                List<Object> rows = mapRows(rs, resultType);
                if (method.getReturnType() == List.class) {
                    return rows;
                }
                return rows.isEmpty() ? null : rows.get(0);
            }
        }
    }

    /**
     * Maps each argument to its {@link Param} value, falling back to the reflective parameter
     * name when that value is empty.
     *
     * @throws IllegalArgumentException if a parameter is not annotated with {@link Param}
     */
    private static Map<String, Object> collectParams(Method method, Object[] args) {
        Map<String, Object> paramsMap = new HashMap<>();
        Parameter[] parameters = method.getParameters();
        for (int i = 0; i < parameters.length; i++) {
            Param param = parameters[i].getAnnotation(Param.class);
            if (param == null) {
                throw new IllegalArgumentException("参数必须修饰");
            }
            String key = param.value().isEmpty() ? parameters[i].getName() : param.value();
            paramsMap.put(key, args[i]);
        }
        return paramsMap;
    }

    /** Returns the element type for {@code List<E>} returns, otherwise the raw return type. */
    private static Class<?> resolveResultType(Method method) {
        Type genericReturnType = method.getGenericReturnType();
        if (genericReturnType instanceof ParameterizedType) {
            return (Class<?>) ((ParameterizedType) genericReturnType).getActualTypeArguments()[0];
        }
        return method.getReturnType();
    }

    /**
     * Binds each placeholder (JDBC index = position + 1) to the argument registered under its name.
     *
     * @throws IllegalArgumentException if a placeholder has no matching {@code @Param}
     */
    private static void bindParameters(PreparedStatement stmt, List<ParameterMapping> mappings,
                                       Map<String, Object> paramsMap) throws SQLException {
        for (int i = 0; i < mappings.size(); i++) {
            String property = mappings.get(i).getProperty();
            if (!paramsMap.containsKey(property)) {
                throw new IllegalArgumentException("No @Param named '" + property + "' for placeholder #{" + property + "}");
            }
            Object value = paramsMap.get(property);
            if (value == null) {
                // A null argument has no runtime type to pick a handler from; bind SQL NULL.
                stmt.setNull(i + 1, Types.NULL);
            } else {
                handlerFor(value.getClass()).setParameter(stmt, i + 1, value);
            }
        }
    }

    /**
     * Instantiates {@code resultType} for every row and fills it through setters whose property
     * names equal the column labels. Columns without a matching setter are skipped.
     */
    private static List<Object> mapRows(ResultSet rs, Class<?> resultType)
            throws SQLException, ReflectiveOperationException {
        Map<String, Method> setters = findSetters(resultType);
        ResultSetMetaData metaData = rs.getMetaData();
        List<String> columnNames = new ArrayList<>();
        for (int i = 1; i <= metaData.getColumnCount(); i++) {
            // The label honours "AS alias" and equals the column name when no alias is given.
            columnNames.add(metaData.getColumnLabel(i));
        }
        List<Object> rows = new ArrayList<>();
        while (rs.next()) {
            Object row = resultType.getDeclaredConstructor().newInstance();
            for (String columnName : columnNames) {
                Method setter = setters.get(columnName);
                if (setter == null) {
                    continue;
                }
                Class<?> propertyType = setter.getParameterTypes()[0];
                Object value = handlerFor(propertyType).getResult(rs, columnName);
                if (value == null && propertyType.isPrimitive()) {
                    // SQL NULL cannot be passed to a primitive setter; keep the field's default.
                    continue;
                }
                setter.invoke(row, value);
            }
            rows.add(row);
        }
        return rows;
    }

    /** Maps property name to the public one-argument {@code setXxx} method of {@code type}. */
    private static Map<String, Method> findSetters(Class<?> type) {
        Map<String, Method> setters = new HashMap<>();
        for (Method candidate : type.getMethods()) {
            String name = candidate.getName();
            if (name.length() > 3 && name.startsWith("set") && candidate.getParameterCount() == 1) {
                String property = name.substring(3);
                property = property.substring(0, 1).toLowerCase(Locale.ROOT) + property.substring(1);
                setters.put(property, candidate);
            }
        }
        return setters;
    }

    /**
     * Looks up the handler registered for {@code type}; primitive types use their wrapper's handler.
     *
     * @throws IllegalArgumentException if no handler is registered for the type
     */
    @SuppressWarnings("unchecked")
    private static TypeHandler<Object> handlerFor(Class<?> type) {
        // The element of a one-element primitive array boxes to the wrapper type (int -> Integer).
        Class<?> key = type.isPrimitive()
                ? java.lang.reflect.Array.get(java.lang.reflect.Array.newInstance(type, 1), 0).getClass()
                : type;
        TypeHandler<?> handler = TYPE_HANDLERS.get(key);
        if (handler == null) {
            throw new IllegalArgumentException("No TypeHandler registered for " + type.getName());
        }
        // Safe: each handler is registered under exactly the type it reads and writes.
        return (TypeHandler<Object>) handler;
    }
}

 

 

业务调用部分的代码

<?xml version="1.0" encoding="UTF-8"?>
<!--
  Maven build for the HzMyBatis demo: a Spring Boot web application that runs annotation-driven
  SQL through a hand-written mapper proxy over plain JDBC.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <!-- Spring Boot parent: supplies the versions of the dependencies below that declare none. -->
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.7.0</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <groupId>com.hz</groupId>
    <artifactId>HzMyBatis</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>HzMyBatis</name>
    <description>HzMyBatis</description>
    <properties>
        <!-- Java language level used for compilation (Java 8). -->
        <java.version>1.8</java.version>
    </properties>
    <dependencies>
        <!-- Spring MVC support for the @RestController classes. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>

        <!-- JSON serialization used by UserController (JSONArray / JSONObject.toJSONString).
             NOTE(review): 1.2.35 is an old fastjson release; later 1.2.x releases fixed several
             autoType deserialization vulnerabilities. Consider upgrading; confirm compatibility. -->
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.35</version>
        </dependency>

        <!-- Database driver: MySQL JDBC driver; no version here, so it comes from the Spring Boot parent. -->
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <!-- Packages the application as an executable Spring Boot jar. -->
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>

</project>
package com.hz.mybatis.controller;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.hz.mybatis.pojo.User;
import com.hz.mybatis.service.UserService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.List;

/**
 * REST endpoints under {@code /user} that return users as JSON text produced by fastjson.
 */
@RestController
@RequestMapping("user")
public class UserController {

    @Autowired
    private UserService userService;

    /**
     * Returns the users the service finds for the given name and age.
     *
     * @param username request parameter {@code username}
     * @param age      request parameter {@code age}
     * @return JSON array text of the users
     */
    @RequestMapping("listAll")
    public String listAll(@RequestParam("username") String username, @RequestParam("age") Integer age) {
        List<User> users = userService.listAll(username, age);
        return JSONArray.toJSONString(users);
    }

    /**
     * Returns the user the service finds for {@code id}.
     *
     * @param id path variable {@code id}
     * @return JSON object text of the user
     */
    @RequestMapping("getById/{id}")
    public String getById(@PathVariable("id") Integer id) {
        User found = userService.getById(id);
        return JSONObject.toJSONString(found);
    }
}
package com.hz.mybatis.pojo;

/**
 * User row: id, name and age, populated through the setters below.
 */
public class User {

    /** Row identifier. */
    private Integer id;

    /** User name. */
    private String name;

    /** User age. */
    private Integer age;

    /** Creates an empty user; all fields are {@code null}. */
    public User() {
    }

    /**
     * Creates a fully populated user.
     *
     * @param id   row identifier
     * @param name user name
     * @param age  user age
     */
    public User(Integer id, String name, Integer age) {
        this.id = id;
        this.name = name;
        this.age = age;
    }

    public Integer getId() {
        return this.id;
    }

    public void setId(Integer id) {
        this.id = id;
    }

    public String getName() {
        return this.name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public Integer getAge() {
        return this.age;
    }

    public void setAge(Integer age) {
        this.age = age;
    }
}
package com.hz.mybatis.service;

import com.hz.mybatis.pojo.User;

import java.util.List;

/**
 * User lookups used by the web layer.
 */
public interface UserService {

    /**
     * Looks up users by name and age.
     *
     * @param name user name to match
     * @param age  user age to match
     * @return the matching users
     */
    List<User> listAll(String name, Integer age);

    /**
     * Looks up a single user by id.
     *
     * @param id user id
     * @return the user for {@code id}
     */
    User getById(Integer id);
}
package com.hz.mybatis.service.impl;

import com.hz.mybatis.mapper.MapperProxyFactory;
import com.hz.mybatis.mapper.UserMapper;
import com.hz.mybatis.pojo.User;
import com.hz.mybatis.service.UserService;
import org.springframework.stereotype.Service;
import java.util.List;

/**
 * {@link UserService} backed by the proxy-generated {@link UserMapper}.
 */
@Service
public class UserServiceImpl implements UserService {

    // The mapper proxy keeps no per-call state, so a single instance serves every request
    // instead of building a new proxy on each call.
    private final UserMapper userMapper = MapperProxyFactory.getMapper(UserMapper.class);

    /**
     * @param name user name to match
     * @param age  user age to match
     * @return users returned by {@link UserMapper#selectAll}
     */
    @Override
    public List<User> listAll(String name, Integer age) {
        // selectAll is already typed List<User>; no Object detour or unchecked cast needed.
        return userMapper.selectAll(name, age);
    }

    /**
     * @param id user id
     * @return the user returned by {@link UserMapper#getById}
     */
    @Override
    public User getById(Integer id) {
        return userMapper.getById(id);
    }
}
package com.hz.mybatis.mapper;

import com.hz.mybatis.annotation.Param;
import com.hz.mybatis.annotation.Select;
import com.hz.mybatis.pojo.User;
import org.springframework.stereotype.Component;

import java.util.List;

/**
 * Mapper for table {@code t_user}; implemented at runtime by a dynamic proxy that executes the
 * {@code @Select} SQL below.
 */
public interface UserMapper {

    /**
     * Selects the users with the given name and age.
     *
     * @param name bound to the {@code #{userName}} placeholder
     * @param age  bound to the {@code #{age}} placeholder
     * @return the matching users
     */
    @Select("select * from t_user where name = #{userName} and age = #{age}")
    List<User> selectAll(@Param("userName") String name, @Param("age") Integer age);

    /**
     * Selects the user with the given id.
     *
     * @param id bound to the {@code #{id}} placeholder
     * @return the matching user
     */
    @Select("select * from t_user where id = #{id}")
    User getById(@Param("id") Integer id);

}

 

posted @ 2023-05-11 16:09  蔡徐坤1987  阅读(23)  评论(0编辑  收藏  举报