徒手编写 Spring 的初始化流程：山寨版 IoC 容器

新建一个简单的 Web 工程。

工程目录:

配置 application.properties（放在 classpath 根目录下，例如 Maven 工程的 src/main/resources；注意 .properties 文件中 # 只能写在行首作注释）:

# Base package scanned for @GysController / @GysService classes.
# NOTE: '#' starts a comment only at the beginning of a line in .properties files;
# a trailing "# ..." after the value would become part of the value itself.
scanPackage=com.gys.demo

编写注解

package annotation;

import java.lang.annotation.*;

/**
 * Marks a field that the dispatcher fills from its bean container after instantiation.
 *
 * <p>{@link #value()} may name the bean to inject explicitly; when left blank the
 * dispatcher derives the bean name from the field's declared type.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface GysAutowired {
    /** Explicit bean name to inject; an empty string means "derive it from the field type". */
    String value() default "";
}

 

package annotation;

import java.lang.annotation.*;

/**
 * Marks a class as a request-handling controller.
 *
 * <p>The dispatcher instantiates annotated classes and scans their
 * {@code @GysRequestMapping} methods to build its URL handlers.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface GysController {
    /** Optional name attribute; empty by default. */
    String value() default "";
}

 

package annotation;

import java.lang.annotation.*;

/**
 * Maps a URL to a controller.
 *
 * <p>On a class it supplies a path prefix; on a method it supplies the path that is
 * appended to that prefix. The dispatcher compiles the combined string as a
 * regular expression and matches it against the context-relative request URI.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface GysRequestMapping {
    /** The path (or path prefix, on a class); empty by default. */
    String value() default "";
}
package annotation;

import java.lang.annotation.*;

/**
 * Binds a handler-method parameter to the HTTP request parameter named by {@link #value()}.
 *
 * <p>A blank name is ignored by the dispatcher, leaving the argument unbound.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface GysRequestParam {
    /** Name of the request parameter to bind; empty by default. */
    String value() default "";
}

 

package annotation;

import java.lang.annotation.*;

/**
 * Marks a class as a service bean for the dispatcher's container.
 *
 * <p>The instance is registered under {@link #value()} (or, when blank, the simple class
 * name with a lower-cased first letter) and also under the fully-qualified name of every
 * interface the class directly implements.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface GysService {
    /** Optional explicit bean name; empty means "use the default derived name". */
    String value() default "";
}

编写 service:

package com.gys.demo.service;

/**
 * Greeting service contract consumed by the demo controller.
 */
public interface IDemoService {

    /**
     * Produces the response text for a visitor.
     *
     * @param name the visitor name; callers may pass {@code null} when the request omits it
     * @return the text to send back to the client
     */
    String get(String name);
}
package com.gys.demo.service.impl;


import annotation.GysService;
import com.gys.demo.service.IDemoService;

/**
 * Default {@link IDemoService}: renders an {@code <h1>} greeting.
 */
@GysService
public class DemoService implements IDemoService {

    /**
     * Builds an HTML greeting for {@code name}.
     *
     * <p>The name may come straight from an untrusted request parameter, so it is
     * HTML-escaped before being embedded; otherwise a value such as
     * {@code <script>...</script>} would be reflected into the page (XSS).
     * A {@code null} name renders as {@code "null"}, as plain concatenation did.
     *
     * @param name visitor name, possibly {@code null}
     * @return the greeting markup
     */
    @Override
    public String get(String name) {
        return "<h1>Hello," + escapeHtml(String.valueOf(name)) + "</h1>";
    }

    /** Escapes the five HTML-significant characters so text is safe inside element content or attributes. */
    private static String escapeHtml(String text) {
        StringBuilder out = new StringBuilder(text.length());
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            switch (c) {
                case '&':
                    out.append("&amp;");
                    break;
                case '<':
                    out.append("&lt;");
                    break;
                case '>':
                    out.append("&gt;");
                    break;
                case '"':
                    out.append("&quot;");
                    break;
                case '\'':
                    out.append("&#x27;");
                    break;
                default:
                    out.append(c);
            }
        }
        return out.toString();
    }
}

controller代码

package com.gys.demo.controller;

import annotation.GysAutowired;
import annotation.GysController;
import annotation.GysRequestMapping;
import annotation.GysRequestParam;
import com.gys.demo.service.IDemoService;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.io.IOException;

/**
 * Demo controller mounted under the {@code /demo} prefix.
 */
@GysController
@GysRequestMapping("/demo")
public class DemoController {

    /** Greeting service, filled in by the dispatcher's {@code @GysAutowired} injection. */
    @GysAutowired
    private IDemoService iDemoService;

    /**
     * Handles {@code /demo/query}: logs the call, then writes the greeting produced for
     * the {@code name} request parameter to the response.
     *
     * @param request  current request (unused here, injected by the dispatcher)
     * @param session  current session (unused here, injected by the dispatcher)
     * @param response response the greeting is written to
     * @param name     value of the {@code name} request parameter, or {@code null} if absent
     * @throws IOException if the response writer cannot be obtained or written
     */
    @GysRequestMapping("/query")
    public void query(HttpServletRequest request, HttpSession session, HttpServletResponse response,
                      @GysRequestParam("name") String name) throws IOException {
        System.out.println("query..............");
        // Compute the greeting before touching the writer, as the original ordering did.
        final String greeting = iDemoService.get(name);
        response.getWriter().write(greeting);
    }
}

新建 Servlet，并在 web.xml 中进行配置:

<!DOCTYPE web-app PUBLIC
        "-//Sun Microsystems, Inc.//DTD Web Application 2.3//EN"
        "http://java.sun.com/dtd/web-app_2_3.dtd" >

<web-app>
    <display-name>Archetype Created Web Application</display-name>

    <!-- Front-controller servlet: builds the mini IoC container and URL handlers in init(). -->
    <servlet>
        <servlet-name>gysMvc</servlet-name>
        <servlet-class>servlet.GysDispatcherServlet</servlet-class>
        <!-- Classpath resource loaded by the servlet; it must define scanPackage. -->
        <init-param>
            <param-name>contextConfigLocation</param-name>
            <param-value>application.properties</param-value>
        </init-param>
        <!-- Initialise at deployment time instead of on the first request. -->
        <load-on-startup>1</load-on-startup>
    </servlet>

    <!-- Route every request path to the dispatcher servlet. -->
    <servlet-mapping>
        <servlet-name>gysMvc</servlet-name>
        <url-pattern>/*</url-pattern>
    </servlet-mapping>
</web-app>

核心代码 GysDispatcherServlet 如下（代码较长，原文中已折叠）。

package servlet;

import annotation.*;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * A minimal hand-written IoC container plus MVC front controller.
 *
 * <p>During {@link #init(ServletConfig)} it:
 * <ol>
 *   <li>loads the properties file named by the {@code contextConfigLocation} init-param,</li>
 *   <li>collects every class under the {@code scanPackage} property,</li>
 *   <li>instantiates {@link GysController} / {@link GysService} classes into a name-to-bean map,</li>
 *   <li>injects {@link GysAutowired} fields from that map, and</li>
 *   <li>turns {@link GysRequestMapping} methods into URL-regex handlers.</li>
 * </ol>
 * Each request is then dispatched to the first handler whose pattern matches the
 * context-relative request URI.
 *
 * <p>Limitations: only packages located in a file-system directory (not inside a JAR) can be
 * scanned; beans are singletons created through their no-arg constructor; mapping values are
 * treated as regular expressions.
 */
public class GysDispatcherServlet extends HttpServlet {

    /** Settings read from the {@code contextConfigLocation} file. */
    private final Properties contextConfig = new Properties();
    /** Fully-qualified names of every {@code .class} file found under the scan package. */
    private final List<String> classNames = new ArrayList<>();
    /** Bean name (or fully-qualified interface name) to singleton instance. */
    private final Map<String, Object> ioc = new HashMap<>();
    /** URL handlers, tried in insertion order. */
    private final List<Handler> handlerMapping = new ArrayList<>();

    /**
     * Builds the container and the handler table.
     *
     * @throws ServletException if the configuration is missing/unreadable, the scan package
     *                          cannot be found, a bean cannot be created, or a dependency is missing
     */
    @Override
    public void init(ServletConfig config) throws ServletException {
        // Must be called when overriding init(ServletConfig); otherwise getServletConfig(),
        // getServletContext() and log() do not work.
        super.init(config);

        // 1. Load the configuration file.
        doLoadConfig(config.getInitParameter("contextConfigLocation"));
        // 2. Collect candidate classes.
        doScanner(contextConfig.getProperty("scanPackage"));
        // 3. Instantiate annotated classes into the IoC map.
        doInstance();
        // 4. Dependency injection.
        doAutowired();
        // 5. URL -> method handlers.
        initHandlerMapping();
        System.out.println("servlet init finsh==================");
    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        this.doPost(req, resp);
    }

    /**
     * Sets UTF-8 encodings and dispatches the request. Handler failures become a 500 response;
     * unconvertible parameters become a 400. Details go to the server log, never to the client.
     */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        req.setCharacterEncoding("utf-8");
        resp.setCharacterEncoding("utf-8");
        resp.setContentType("text/html;charset=utf-8");
        try {
            doDispatch(req, resp);
        } catch (InvocationTargetException e) {
            // The handler method itself threw: log the real cause.
            log("Handler failed for " + req.getRequestURI(), e.getCause() != null ? e.getCause() : e);
            resp.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
            resp.getWriter().write("500 Exception");
        } catch (IllegalAccessException e) {
            log("Handler not accessible for " + req.getRequestURI(), e);
            resp.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
            resp.getWriter().write("500 Exception");
        } catch (IllegalArgumentException e) {
            // Raised by convert() (e.g. NumberFormatException) or by invoke() when an argument
            // cannot be bound, such as a missing value for a primitive parameter.
            log("Bad request parameters for " + req.getRequestURI(), e);
            resp.setStatus(HttpServletResponse.SC_BAD_REQUEST);
            resp.getWriter().write("400 Bad Request");
        }
    }

    /** Resolves the handler, binds its arguments, invokes it and writes any non-null result. */
    private void doDispatch(HttpServletRequest request, HttpServletResponse response)
            throws IOException, InvocationTargetException, IllegalAccessException {
        Handler handler = getHandler(request);
        if (handler == null) {
            response.setStatus(HttpServletResponse.SC_NOT_FOUND);
            response.getWriter().write("404 对不起没有您要的页面资源");
            return;
        }
        Class<?>[] parameterTypes = handler.method.getParameterTypes();
        Object[] paramValues = new Object[parameterTypes.length];

        // Bind @GysRequestParam arguments from the request parameters.
        Map<String, String[]> parameterMap = request.getParameterMap();
        for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
            Integer index = handler.paramIndexMapping.get(entry.getKey());
            if (index == null) {
                continue;
            }
            String value = repairLatin1Decoding(joinValues(entry.getValue()));
            paramValues[index] = convert(parameterTypes[index], value);
        }
        // Bind servlet objects declared as parameters.
        if (handler.paramIndexMapping.containsKey(HttpServletRequest.class.getName())) {
            paramValues[handler.paramIndexMapping.get(HttpServletRequest.class.getName())] = request;
        }
        if (handler.paramIndexMapping.containsKey(HttpServletResponse.class.getName())) {
            paramValues[handler.paramIndexMapping.get(HttpServletResponse.class.getName())] = response;
        }
        if (handler.paramIndexMapping.containsKey(HttpSession.class.getName())) {
            // Only create/fetch a session when the handler actually asks for one.
            paramValues[handler.paramIndexMapping.get(HttpSession.class.getName())] = request.getSession();
        }

        Object returnValue = handler.method.invoke(handler.controller, paramValues);
        if (returnValue != null) {
            response.getWriter().write(returnValue.toString());
        }
    }

    /**
     * Joins a multi-valued parameter with ", " (the separator the previous Arrays.toString-based
     * code produced) without altering any characters inside the individual values.
     */
    private static String joinValues(String[] values) {
        StringBuilder joined = new StringBuilder();
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                joined.append(", ");
            }
            joined.append(values[i]);
        }
        return joined.toString();
    }

    /**
     * Repairs a value whose UTF-8 bytes were decoded as ISO-8859-1 (presumably compensating for
     * containers that decode query strings that way), without corrupting values that were
     * already decoded correctly.
     *
     * <p>Only a string made entirely of chars U+0000..U+00FF whose bytes form valid UTF-8 can be
     * such a mis-decoding; anything else is returned unchanged.
     */
    private static String repairLatin1Decoding(String value) {
        for (int i = 0; i < value.length(); i++) {
            if (value.charAt(i) > 0xFF) {
                return value; // already real Unicode text
            }
        }
        byte[] raw = value.getBytes(StandardCharsets.ISO_8859_1);
        try {
            return StandardCharsets.UTF_8.newDecoder()
                    .onMalformedInput(CodingErrorAction.REPORT)
                    .onUnmappableCharacter(CodingErrorAction.REPORT)
                    .decode(ByteBuffer.wrap(raw))
                    .toString();
        } catch (CharacterCodingException e) {
            return value; // not UTF-8 bytes, so it was not mis-decoded
        }
    }

    /**
     * Loads the properties file from the classpath.
     *
     * @throws ServletException if the location is missing, the resource does not exist, or it cannot be read
     */
    private void doLoadConfig(String contextConfigLocation) throws ServletException {
        if (contextConfigLocation == null || contextConfigLocation.trim().isEmpty()) {
            throw new ServletException("Missing init-param 'contextConfigLocation'");
        }
        try (InputStream is = this.getClass().getClassLoader().getResourceAsStream(contextConfigLocation)) {
            if (is == null) {
                throw new ServletException("Config file not found on classpath: " + contextConfigLocation);
            }
            contextConfig.load(is);
        } catch (IOException e) {
            throw new ServletException("Cannot read config file: " + contextConfigLocation, e);
        }
    }

    /**
     * Records the fully-qualified name of every class under {@code scanPackage}.
     *
     * @throws ServletException if the property is blank or the package is not a file-system directory
     */
    private void doScanner(String scanPackage) throws ServletException {
        if (scanPackage == null || scanPackage.trim().isEmpty()) {
            throw new ServletException("Property 'scanPackage' is missing or empty");
        }
        String packageName = scanPackage.trim();
        // ClassLoader resource names are relative: no leading "/".
        URL url = this.getClass().getClassLoader().getResource(packageName.replace('.', '/'));
        if (url == null) {
            throw new ServletException("Scan package not found on classpath: " + packageName);
        }
        File classDir;
        try {
            // toURI() decodes %20 and similar escapes that URL.getFile() would leave in place.
            classDir = new File(url.toURI());
        } catch (URISyntaxException | IllegalArgumentException e) {
            throw new ServletException("Scan package is not a file-system directory: " + url, e);
        }
        scanDirectory(classDir, packageName);
    }

    /** Recursively adds {@code packageName.SimpleName} for every .class file below {@code dir}. */
    private void scanDirectory(File dir, String packageName) {
        File[] files = dir.listFiles();
        if (files == null) {
            return; // not a directory or not readable
        }
        for (File file : files) {
            String fileName = file.getName();
            if (file.isDirectory()) {
                scanDirectory(file, packageName + "." + fileName);
            } else if (fileName.endsWith(".class")) {
                // Strip only the trailing suffix, not any inner ".class" substring.
                String simpleName = fileName.substring(0, fileName.length() - ".class".length());
                classNames.add(packageName + "." + simpleName);
            }
        }
    }

    /**
     * Instantiates every {@link GysController} / {@link GysService} class. A service is also
     * registered under the fully-qualified name of each interface it directly implements.
     *
     * @throws ServletException if any class cannot be loaded or instantiated (fail fast rather
     *                          than silently skipping the remaining classes)
     */
    private void doInstance() throws ServletException {
        for (String className : classNames) {
            try {
                Class<?> clazz = Class.forName(className);
                if (clazz.isAnnotationPresent(GysController.class)) {
                    String beanName = beanNameOf(clazz, clazz.getAnnotation(GysController.class).value());
                    ioc.put(beanName, clazz.getDeclaredConstructor().newInstance());
                } else if (clazz.isAnnotationPresent(GysService.class)) {
                    Object instance = clazz.getDeclaredConstructor().newInstance();
                    ioc.put(beanNameOf(clazz, clazz.getAnnotation(GysService.class).value()), instance);
                    for (Class<?> inter : clazz.getInterfaces()) {
                        ioc.put(inter.getName(), instance);
                    }
                }
            } catch (ReflectiveOperationException e) {
                throw new ServletException("Cannot instantiate bean class " + className, e);
            }
        }
    }

    /** Returns the trimmed declared name, or the simple class name with a lower-cased first letter. */
    private String beanNameOf(Class<?> clazz, String declaredName) {
        String name = declaredName.trim();
        return name.isEmpty() ? toLowerFirstCase(clazz.getSimpleName()) : name;
    }

    /**
     * Fills every {@link GysAutowired} field. Interface-typed fields are resolved by the
     * interface's fully-qualified name, class-typed fields by lower-camel simple name,
     * unless the annotation names a bean explicitly.
     *
     * @throws ServletException if a required bean is missing or the field cannot be set
     */
    private void doAutowired() throws ServletException {
        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            Object bean = entry.getValue();
            for (Field field : bean.getClass().getDeclaredFields()) {
                GysAutowired autowired = field.getAnnotation(GysAutowired.class);
                if (autowired == null) {
                    continue;
                }
                String beanName = autowired.value().trim();
                if (beanName.isEmpty()) {
                    Class<?> type = field.getType();
                    beanName = type.isInterface() ? type.getName() : toLowerFirstCase(type.getSimpleName());
                }
                Object dependency = ioc.get(beanName);
                if (dependency == null) {
                    // Injecting null would only surface later as an NPE inside a request.
                    throw new ServletException("No bean named '" + beanName + "' to inject into " + field);
                }
                field.setAccessible(true);
                try {
                    field.set(bean, dependency);
                } catch (IllegalAccessException e) {
                    throw new ServletException("Cannot inject into " + field, e);
                }
            }
        }
    }

    /**
     * Creates one {@link Handler} per {@link GysRequestMapping} method of each controller. The
     * class-level and method-level values are joined with a single "/" and compiled as a regex.
     */
    private void initHandlerMapping() {
        for (Object bean : ioc.values()) {
            Class<?> clazz = bean.getClass();
            if (!clazz.isAnnotationPresent(GysController.class)) {
                continue;
            }
            String baseUrl = "";
            GysRequestMapping classMapping = clazz.getAnnotation(GysRequestMapping.class);
            if (classMapping != null) {
                baseUrl = classMapping.value();
            }
            for (Method method : clazz.getMethods()) {
                GysRequestMapping methodMapping = method.getAnnotation(GysRequestMapping.class);
                if (methodMapping == null) {
                    continue;
                }
                String path = methodMapping.value().isEmpty() ? baseUrl : baseUrl + "/" + methodMapping.value();
                // Normalise so "demo"+"query" and "/demo"+"/query" both become "/demo/query".
                String regex = ("/" + path).replaceAll("/+", "/");
                handlerMapping.add(new Handler(Pattern.compile(regex), bean, method));
                System.out.println("Mapped:" + regex + "," + method);
            }
        }
    }

    /** Lower-cases the first character; returns {@code null}/empty input unchanged. */
    private String toLowerFirstCase(String simpleName) {
        if (simpleName == null || simpleName.isEmpty()) {
            return simpleName;
        }
        return Character.toLowerCase(simpleName.charAt(0)) + simpleName.substring(1);
    }

    /** A compiled URL pattern bound to a controller method and its argument positions. */
    private static final class Handler {
        final Object controller;   // instance the method is invoked on
        final Method method;       // mapped handler method
        final Pattern pattern;     // URL regex
        /** Request-parameter name or servlet type name -> argument index. */
        final Map<String, Integer> paramIndexMapping = new HashMap<>();

        Handler(Pattern pattern, Object controller, Method method) {
            this.controller = controller;
            this.method = method;
            this.pattern = pattern;
            putParamIndexMapping(method);
        }

        /** Indexes @GysRequestParam names and HttpServletRequest/Response/HttpSession parameters. */
        private void putParamIndexMapping(Method method) {
            Annotation[][] annotations = method.getParameterAnnotations();
            for (int i = 0; i < annotations.length; i++) {
                for (Annotation annotation : annotations[i]) {
                    if (annotation instanceof GysRequestParam) {
                        String paramName = ((GysRequestParam) annotation).value();
                        if (!paramName.trim().isEmpty()) {
                            paramIndexMapping.put(paramName, i);
                        }
                    }
                }
            }
            Class<?>[] paramTypes = method.getParameterTypes();
            for (int i = 0; i < paramTypes.length; i++) {
                Class<?> type = paramTypes[i];
                if (type == HttpServletRequest.class || type == HttpServletResponse.class || type == HttpSession.class) {
                    paramIndexMapping.put(type.getName(), i);
                }
            }
        }
    }

    /** Returns the first handler whose pattern matches the context-relative URI, or {@code null}. */
    private Handler getHandler(HttpServletRequest request) {
        String url = request.getRequestURI();
        String contextPath = request.getContextPath();
        // Strip only the leading context path; String.replace would remove every occurrence.
        if (contextPath != null && !contextPath.isEmpty() && url.startsWith(contextPath)) {
            url = url.substring(contextPath.length());
        }
        for (Handler handler : handlerMapping) {
            Matcher matcher = handler.pattern.matcher(url);
            if (matcher.matches()) {
                return handler;
            }
        }
        return null;
    }

    /**
     * Converts a request-parameter string to the target parameter type. HTTP parameters are
     * always strings; boxed and primitive int/long/double/boolean are supported, anything else
     * receives the raw string.
     *
     * @throws NumberFormatException if a numeric type is requested and the value is not a number
     */
    private Object convert(Class<?> type, String value) {
        if (type == Integer.class || type == int.class) {
            return Integer.valueOf(value);
        }
        if (type == Long.class || type == long.class) {
            return Long.valueOf(value);
        }
        if (type == Double.class || type == double.class) {
            return Double.valueOf(value);
        }
        if (type == Boolean.class || type == boolean.class) {
            return Boolean.valueOf(value);
        }
        return value;
    }

}
View Code

 

运行项目：部署后在浏览器访问 {contextPath}/demo/query?name=xxx 即可看到输出结果。

 

posted @ 2020-04-24 15:10  思思博士  阅读(284)  评论(0编辑  收藏  举报