Hand-Writing a Simple SpringMVC
I. Preface
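This is a small demo in which I implement a simplified version of SpringMVC's functionality myself. The goal: pass a name parameter through the browser address bar and have the page print "my name is " + name. SpringMVC itself is not used. Instead, I define a few annotations of my own and implement the core functionality of DispatcherServlet. Building this demo helps deepen my understanding of the SpringMVC source code.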
Here is the result first:
(with the name parameter passed)
(without the name parameter)
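II. The DispatcherServlet Flow
- Load the configuration file
- Scan all related classes
- Instantiate all related classes
- Inject dependencies (autowiring)
- Initialize the HandlerMapping
- Wait for requests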
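To see the six steps in one place before the servlet version, here is a minimal servlet-free sketch. It is not the demo's code: every class and annotation name below is hypothetical, steps 1 and 2 (reading the properties file and scanning the package) are replaced by passing the classes in directly, and step 6 uses an exact URL lookup instead of regex matching.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

// Servlet-free sketch of the DispatcherServlet flow; every name here is hypothetical.
public class MiniDispatcher {

    @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) @interface Controller {}
    @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) @interface Service {}
    @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.FIELD) @interface Autowired {}
    @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.TYPE, ElementType.METHOD})
    @interface RequestMapping { String value(); }

    public interface GreetService { String greet(String name); }

    @Service
    public static class GreetServiceImpl implements GreetService {
        public String greet(String name) { return "my name is " + name; }
    }

    @Controller
    @RequestMapping("/demo")
    public static class GreetController {
        @Autowired GreetService service;

        @RequestMapping("/get")
        public String get(String name) { return service.greet(name); }
    }

    private final Map<String, Object> iocMap = new HashMap<>();      // bean name -> instance
    private final Map<String, Method> urlToMethod = new HashMap<>(); // URL -> handler method
    private final Map<String, Object> urlToBean = new HashMap<>();   // URL -> controller instance

    // Steps 1 and 2 (load config, scan package) are replaced by passing the classes in directly.
    public MiniDispatcher(Class<?>... scannedClasses) {
        try {
            // Step 3: instantiate annotated classes; services are also keyed by interface name
            for (Class<?> clazz : scannedClasses) {
                boolean isService = clazz.isAnnotationPresent(Service.class);
                if (!isService && !clazz.isAnnotationPresent(Controller.class)) {
                    continue;
                }
                Object bean = clazz.getDeclaredConstructor().newInstance();
                iocMap.put(clazz.getName(), bean);
                if (isService) {
                    for (Class<?> itf : clazz.getInterfaces()) {
                        iocMap.put(itf.getName(), bean);
                    }
                }
            }
            // Step 4: assign @Autowired fields, looked up by the field's type name
            for (Object bean : iocMap.values()) {
                for (Field field : bean.getClass().getDeclaredFields()) {
                    if (field.isAnnotationPresent(Autowired.class)) {
                        field.setAccessible(true);
                        field.set(bean, iocMap.get(field.getType().getName()));
                    }
                }
            }
            // Step 5: build the URL mapping from class- and method-level @RequestMapping
            for (Object bean : iocMap.values()) {
                Class<?> clazz = bean.getClass();
                if (!clazz.isAnnotationPresent(Controller.class)) {
                    continue;
                }
                String base = clazz.isAnnotationPresent(RequestMapping.class)
                        ? clazz.getAnnotation(RequestMapping.class).value() : "";
                for (Method method : clazz.getMethods()) {
                    if (method.isAnnotationPresent(RequestMapping.class)) {
                        String url = ("/" + base + method.getAnnotation(RequestMapping.class).value())
                                .replaceAll("/+", "/");
                        urlToMethod.put(url, method);
                        urlToBean.put(url, bean);
                    }
                }
            }
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }

    // Step 6: handle a "request" (exact URL lookup instead of regex matching)
    public String dispatch(String url, String name) {
        Method method = urlToMethod.get(url);
        if (method == null) {
            return "404 Not Found";
        }
        try {
            return (String) method.invoke(urlToBean.get(url), name);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

For example, new MiniDispatcher(GreetServiceImpl.class, GreetController.class).dispatch("/demo/get", "tom") returns "my name is tom", and an unmapped URL returns "404 Not Found".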
III. Code Walkthrough
1. First, the dependencies in the pom file:
```xml
<dependencies>
    <!-- Servlet API: supplied by the servlet container at run time -->
    <dependency>
        <groupId>javax.servlet</groupId>
        <artifactId>servlet-api</artifactId>
        <version>2.5</version>
        <scope>provided</scope>
    </dependency>
    <!-- StringUtils -->
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-lang3</artifactId>
        <version>3.10</version>
    </dependency>
    <!-- @Slf4j: only needed at compile time -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <version>1.18.12</version>
        <scope>provided</scope>
    </dependency>
    <!-- Logging backend -->
    <dependency>
        <groupId>ch.qos.logback</groupId>
        <artifactId>logback-core</artifactId>
        <version>1.2.3</version>
    </dependency>
    <dependency>
        <groupId>ch.qos.logback</groupId>
        <artifactId>logback-classic</artifactId>
        <version>1.2.3</version>
    </dependency>
</dependencies>
```
There are only a few dependencies and none of them is Spring; the essential one is the Servlet API.
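2. Configuration files:
2.1. application.properties:
```properties
scanPackage=com.qunar.framework.demo
```
scanPackage tells DispatcherServlet which package (including its sub-packages) to scan for classes.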
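The servlet reads this value with java.util.Properties and later turns the dotted package name into a directory path on the class path. A small sketch of those two steps, assuming the file content is passed in as a string (the class and method names are hypothetical):

```java
import java.io.IOException;
import java.io.StringReader;
import java.util.Properties;

// Sketch: read scanPackage the way loadConfig() does, then turn it into a class-path directory name.
public class ScanPackageDemo {

    // Parses properties text (a stand-in for the application.properties file)
    public static String readScanPackage(String propertiesText) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(propertiesText));
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
        return props.getProperty("scanPackage");
    }

    // "com.qunar.framework.demo" -> "com/qunar/framework/demo"
    public static String toResourcePath(String scanPackage) {
        // String.replace(char, char) is literal; replaceAll would need the regex "\\."
        return scanPackage.replace('.', '/');
    }

    public static void main(String[] args) {
        String pkg = readScanPackage("scanPackage=com.qunar.framework.demo");
        System.out.println(pkg);                 // com.qunar.framework.demo
        System.out.println(toResourcePath(pkg)); // com/qunar/framework/demo
    }
}
```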
2.2. web.xml (maps every request to DispatcherServlet and passes it the location of the properties file):
```xml
<!DOCTYPE web-app PUBLIC
        "-//Sun Microsystems, Inc.//DTD Web Application 2.3//EN"
        "http://java.sun.com/dtd/web-app_2_3.dtd" >
<web-app>
    <display-name>MySpringMVC</display-name>
    <servlet>
        <servlet-name>mvc</servlet-name>
        <servlet-class>com.qunar.framework.webmvc.DispatcherServlet</servlet-class>
        <init-param>
            <param-name>contextConfigLocation</param-name>
            <param-value>/application.properties</param-value>
        </init-param>
        <load-on-startup>1</load-on-startup>
    </servlet>
    <servlet-mapping>
        <servlet-name>mvc</servlet-name>
        <url-pattern>/*</url-pattern>
    </servlet-mapping>
</web-app>
```
3. The directory structure of the whole project:
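4. Custom annotations (each lives in its own file):
@Controller:
```java
import java.lang.annotation.*;

// Marks a class as a controller; value() is an optional bean name
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Controller {
    String value() default "";
}
```
@Service:
```java
import java.lang.annotation.*;

// Marks a class as a service; value() is an optional bean name
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Service {
    String value() default "";
}
```
@Autowired:
```java
import java.lang.annotation.*;

// Marks a field for injection; value() is an optional bean name
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Autowired {
    String value() default "";
}
```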
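@RequestMapping:
```java
import java.lang.annotation.*;

// URL mapping; used on controller classes (prefix) and on their methods (path)
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RequestMapping {
    String value() default "";
}
```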
@RequestParam:
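```java
import java.lang.annotation.*;

// Binds a method parameter to the request parameter named by value()
@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RequestParam {
    String value() default "";
}
```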
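All of these annotations use RetentionPolicy.RUNTIME because DispatcherServlet discovers them through reflection, and @RequestMapping is read on both classes and methods by initHandlerMapping(), so its @Target has to include TYPE and METHOD. A small sketch of what retention changes, using hypothetical annotation names:

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

public class RetentionDemo {

    // Visible to reflection at run time, like the custom annotations above
    @Retention(RetentionPolicy.RUNTIME)
    @Target({ElementType.TYPE, ElementType.METHOD})
    @interface Mapping { String value() default ""; }

    // CLASS retention: written to the .class file but not returned by reflection
    @Retention(RetentionPolicy.CLASS)
    @Target(ElementType.TYPE)
    @interface ClassOnly {}

    @Mapping("/demo") @ClassOnly
    public static class Sample {
        @Mapping("/get")
        public void get() {}
    }

    public static boolean runtimeVisibleOnType() {
        return Sample.class.isAnnotationPresent(Mapping.class);
    }

    public static boolean classRetentionVisible() {
        return Sample.class.isAnnotationPresent(ClassOnly.class);
    }

    // Reads the method-level annotation, as initHandlerMapping() does
    public static String methodMapping() {
        try {
            Method m = Sample.class.getMethod("get");
            return m.getAnnotation(Mapping.class).value();
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(runtimeVisibleOnType());  // true
        System.out.println(classRetentionVisible()); // false
        System.out.println(methodMapping());         // /get
    }
}
```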
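5. The custom Handler class:
```java
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;

import org.apache.commons.lang3.StringUtils;
// plus an import for the custom RequestParam annotation (its package is not shown in the original)

// Binds a URL pattern to a controller instance and one of its methods
public class Handler {
    protected Object controller;                  // the controller instance
    protected Method method;                      // the mapped method
    protected Pattern pattern;                    // the URL pattern this handler serves
    protected Map<String, Integer> paramIndexMap; // parameter key -> position in the method signature

    public Handler(Object controller, Method method, Pattern pattern) {
        this.controller = controller;
        this.method = method;
        this.pattern = pattern;
        this.paramIndexMap = new HashMap<>();
        putParamIndexMapping(method);
    }

    private void putParamIndexMapping(Method method) {
        // Record the positions of parameters annotated with @RequestParam
        Annotation[][] annotations = method.getParameterAnnotations();
        for (int i = 0; i < annotations.length; i++) {
            for (Annotation annotation : annotations[i]) {
                if (annotation instanceof RequestParam) {
                    String paramName = ((RequestParam) annotation).value();
                    if (!StringUtils.isBlank(paramName)) {
                        paramIndexMap.put(paramName, i);
                    }
                }
            }
        }
        // Record the positions of the request and response parameters, keyed by class name
        Class<?>[] paramTypes = method.getParameterTypes();
        for (int i = 0; i < paramTypes.length; i++) {
            Class<?> paramType = paramTypes[i];
            if (paramType == HttpServletRequest.class || paramType == HttpServletResponse.class) {
                paramIndexMap.put(paramType.getName(), i);
            }
        }
    }
}
```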
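The key idea in Handler is that getParameterAnnotations() returns one annotation array per parameter, in declaration order, so the outer index is the parameter's position. The sketch below reproduces that mapping without the Servlet API; the controller, its method, and the use of StringBuilder as a stand-in for HttpServletRequest are all hypothetical:

```java
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

public class ParamIndexDemo {

    @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.PARAMETER)
    @interface RequestParam { String value(); }

    // Hypothetical handler: the first parameter stands in for HttpServletRequest
    public static class SampleController {
        public void get(StringBuilder req, @RequestParam("name") String name,
                        @RequestParam("age") Integer age) {}
    }

    // Same idea as Handler.putParamIndexMapping: annotation value -> parameter position
    public static Map<String, Integer> indexParams(Method method) {
        Map<String, Integer> index = new HashMap<>();
        Annotation[][] annotations = method.getParameterAnnotations();
        for (int i = 0; i < annotations.length; i++) {
            for (Annotation a : annotations[i]) {
                if (a instanceof RequestParam) {
                    index.put(((RequestParam) a).value(), i);
                }
            }
        }
        // "Special" parameters are keyed by type name, as Handler does for request/response
        Class<?>[] types = method.getParameterTypes();
        for (int i = 0; i < types.length; i++) {
            if (types[i] == StringBuilder.class) {
                index.put(types[i].getName(), i);
            }
        }
        return index;
    }

    public static Map<String, Integer> sampleIndex() {
        try {
            Method m = SampleController.class.getMethod("get", StringBuilder.class, String.class, Integer.class);
            return indexParams(m);
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        // name -> 1, age -> 2, java.lang.StringBuilder -> 0 (print order may vary)
        System.out.println(sampleIndex());
    }
}
```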
6. The custom DispatcherServlet:
```java
import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
// plus imports for the custom Controller, Service, Autowired and RequestMapping annotations

@Slf4j
public class DispatcherServlet extends HttpServlet {

    private static final long serialVersionUID = 1L;

    // Contents of application.properties
    private Properties contextConfig = new Properties();
    // Fully qualified names of all scanned classes
    private List<String> classNames = new ArrayList<>();
    // The "IoC container": bean name -> instance
    private Map<String, Object> iocMap = new HashMap<>();
    // All URL -> method mappings
    private List<Handler> handlerMapping = new ArrayList<>();

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException {
        this.doPost(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException {
        // Handle the incoming request
        try {
            doDispatch(req, resp);
        } catch (Exception exception) {
            resp.getWriter().write("500 Exception");
            log.error("500 Exception", exception);
        }
    }

    private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws Exception {
        Handler handler = getHandler(req);
        if (handler == null) {
            // No mapping matched: 404
            log.info("404 Not Found");
            resp.getWriter().write("404 Not Found");
            return;
        }

        // Parameter types of the target method
        Class<?>[] parameterTypes = handler.method.getParameterTypes();
        // Argument values for the method, filled in below
        Object[] parameterValues = new Object[parameterTypes.length];

        Map<String, String[]> parameterMap = req.getParameterMap();
        for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
            // Only fill parameters declared with a matching @RequestParam
            Integer index = handler.paramIndexMap.get(entry.getKey());
            if (index == null) {
                continue;
            }
            // Multiple values for the same name are joined with commas
            String value = String.join(",", entry.getValue());
            parameterValues[index] = convert(parameterTypes[index], value);
        }

        // Pass the request and response objects only if the method declares them
        Integer reqIndex = handler.paramIndexMap.get(HttpServletRequest.class.getName());
        if (reqIndex != null) {
            parameterValues[reqIndex] = req;
        }
        Integer respIndex = handler.paramIndexMap.get(HttpServletResponse.class.getName());
        if (respIndex != null) {
            parameterValues[respIndex] = resp;
        }

        handler.method.invoke(handler.controller, parameterValues);
    }

    // Convert the raw string to the declared parameter type (only Integer/int besides String)
    private Object convert(Class<?> parameterType, String value) {
        if (parameterType == Integer.class || parameterType == int.class) {
            return Integer.valueOf(value);
        }
        return value;
    }

    private Handler getHandler(HttpServletRequest req) {
        if (handlerMapping.isEmpty()) {
            return null;
        }
        // Strip the context path and collapse duplicate slashes, e.g. /app//demo/get -> /demo/get
        String requestURI = req.getRequestURI();
        String contextPath = req.getContextPath();
        requestURI = requestURI.substring(contextPath.length()).replaceAll("/+", "/");
        for (Handler handler : handlerMapping) {
            Matcher matcher = handler.pattern.matcher(requestURI);
            if (matcher.matches()) {
                return handler;
            }
        }
        return null;
    }

    @Override
    public void init(ServletConfig config) throws ServletException {
        super.init(config);
        // Startup begins here
        // 1. Load the configuration file
        loadConfig(config.getInitParameter("contextConfigLocation"));
        // 2. Scan the related classes
        doScanner(contextConfig.getProperty("scanPackage"));
        // 3. Instantiate the related classes
        try {
            doInstance();
        } catch (Exception exception) {
            log.error("Execute doInstance method fail.", exception);
        }
        // 4. Inject dependencies
        doAutowired();
        // 5. Initialize the HandlerMapping
        initHandlerMapping();
    }

    private void initHandlerMapping() {
        if (iocMap.isEmpty()) {
            return;
        }
        for (Map.Entry<String, Object> entry : iocMap.entrySet()) {
            Class<?> clazz = entry.getValue().getClass();
            if (!clazz.isAnnotationPresent(Controller.class)) {
                continue;
            }
            // A class-level @RequestMapping provides the URL prefix
            String baseUrl = "";
            if (clazz.isAnnotationPresent(RequestMapping.class)) {
                baseUrl = clazz.getAnnotation(RequestMapping.class).value();
            }
            // Scan all public methods
            for (Method method : clazz.getMethods()) {
                if (!method.isAnnotationPresent(RequestMapping.class)) {
                    continue;
                }
                RequestMapping requestMapping = method.getAnnotation(RequestMapping.class);
                String regex = ("/" + baseUrl + requestMapping.value()).replaceAll("/+", "/");
                Pattern pattern = Pattern.compile(regex);
                handlerMapping.add(new Handler(entry.getValue(), method, pattern));
                log.info("Mapping: {} -> {}", regex, method);
            }
        }
    }

    private void doAutowired() {
        if (iocMap.isEmpty()) {
            return;
        }
        // Assign every @Autowired field of every bean
        for (Map.Entry<String, Object> entry : iocMap.entrySet()) {
            Field[] fields = entry.getValue().getClass().getDeclaredFields();
            for (Field field : fields) {
                if (!field.isAnnotationPresent(Autowired.class)) {
                    continue;
                }
                // Use the explicit bean name if given, otherwise the field type's name
                String beanName = field.getAnnotation(Autowired.class).value().trim();
                if (StringUtils.isBlank(beanName)) {
                    beanName = field.getType().getName();
                }
                field.setAccessible(true);
                try {
                    field.set(entry.getValue(), iocMap.get(beanName));
                } catch (IllegalAccessException e) {
                    log.error("Autowired fail, beanName: {}", beanName, e);
                }
            }
        }
    }

    private void doInstance() throws Exception {
        if (classNames.isEmpty()) {
            return;
        }
        for (String className : classNames) {
            Class<?> clazz = Class.forName(className);
            // Use the custom name if one is given; otherwise default to the lower-cased fully
            // qualified class name (not Spring's "lower-case first letter" convention)
            if (clazz.isAnnotationPresent(Controller.class)) {
                String beanName = clazz.getAnnotation(Controller.class).value();
                if (StringUtils.isBlank(beanName)) {
                    beanName = clazz.getName().toLowerCase();
                }
                iocMap.put(beanName, clazz.getDeclaredConstructor().newInstance());
            } else if (clazz.isAnnotationPresent(Service.class)) {
                String beanName = clazz.getAnnotation(Service.class).value();
                if (StringUtils.isBlank(beanName)) {
                    beanName = clazz.getName().toLowerCase();
                }
                Object instance = clazz.getDeclaredConstructor().newInstance();
                iocMap.put(beanName, instance);
                // Also register the instance under each interface it implements,
                // so fields typed by the interface can be injected
                for (Class<?> clazzInterface : clazz.getInterfaces()) {
                    iocMap.put(clazzInterface.getName(), instance);
                }
            }
        }
    }

    private void doScanner(String scanPackage) {
        // com.qunar.framework.demo -> com/qunar/framework/demo
        // (ClassLoader resource names take no leading slash)
        URL url = this.getClass().getClassLoader().getResource(scanPackage.replace('.', '/'));
        File classDir = new File(url.getFile());
        for (File file : classDir.listFiles()) {
            if (file.isDirectory()) {
                // Recurse into sub-packages
                doScanner(scanPackage + "." + file.getName());
            } else if (file.getName().endsWith(".class")) {
                String className = scanPackage + "." + file.getName().replace(".class", "");
                classNames.add(className);
            }
        }
    }

    private void loadConfig(String location) {
        // location is "/application.properties", an absolute class-path resource
        try (InputStream inputStream = this.getClass().getResourceAsStream(location)) {
            if (inputStream == null) {
                log.error("Config not found, location: {}", location);
                return;
            }
            contextConfig.load(inputStream);
        } catch (IOException e) {
            log.error("Load fail, location: {}", location, e);
        }
    }
}
```
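This is the core class of the demo: it plays the role of SpringMVC's DispatcherServlet, loading the configuration, scanning and instantiating classes, injecting dependencies, building the URL mapping, and dispatching each request to the matching controller method.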
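The URL handling in initHandlerMapping() and getHandler() is plain string and regex work, so it can be tried in isolation. A minimal sketch with hypothetical names; note that because the mapping value is compiled with Pattern.compile, a value such as "/user/.*" (also hypothetical) acts as a wildcard:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

// Sketch of the URL handling in initHandlerMapping() and getHandler(), reduced to strings
public class UrlMappingDemo {

    // "/" + class-level prefix + method-level path, duplicate slashes collapsed
    public static String buildRegex(String baseUrl, String methodUrl) {
        return ("/" + baseUrl + methodUrl).replaceAll("/+", "/");
    }

    // Strip the context path from the request URI and collapse duplicate slashes
    public static String normalize(String requestURI, String contextPath) {
        return requestURI.substring(contextPath.length()).replaceAll("/+", "/");
    }

    // Index of the first pattern that matches the whole path, or -1 (a 404)
    public static int match(List<Pattern> patterns, String path) {
        for (int i = 0; i < patterns.size(); i++) {
            if (patterns.get(i).matcher(path).matches()) {
                return i;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        List<Pattern> patterns = new ArrayList<>();
        patterns.add(Pattern.compile(buildRegex("/demo", "/get")));     // /demo/get
        patterns.add(Pattern.compile(buildRegex("/demo", "/user/.*"))); // /demo/user/.*
        String path = normalize("/myapp//demo/get", "/myapp");
        System.out.println(path);                                  // /demo/get
        System.out.println(match(patterns, path));                 // 0
        System.out.println(match(patterns, "/demo/user/42"));      // 1
    }
}
```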
7. Now it is time to check that our SpringMVC actually works. I wrote a service and a controller for that:
7.1 service:
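```java
// IDemoService.java
public interface IDemoService {
    String get(String name);
}

// DemoServiceImpl.java
// @Service is required: doInstance() only instantiates classes carrying @Controller or @Service,
// so without it the controller's IDemoService field would stay null
@Service
public class DemoServiceImpl implements IDemoService {
    @Override
    public String get(String name) {
        return "my name is " + name;
    }
}
```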
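A field typed IDemoService can only be filled because doInstance() also stores each @Service instance under its interface names, and doAutowired() looks fields up by their declared type name. The sketch below simulates both cases, a registered service and one that was never registered (for example, a class missing @Service); all names are hypothetical:

```java
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.Map;

public class InjectionDemo {

    public interface IGreeter { String get(String name); }

    public static class GreeterImpl implements IGreeter {
        public String get(String name) { return "my name is " + name; }
    }

    public static class Consumer {
        IGreeter service; // carries @Autowired in the real controller

        public IGreeter getService() { return service; }
    }

    // registerService = true mimics a class marked @Service; false mimics one without it
    public static Consumer wire(boolean registerService) {
        Map<String, Object> iocMap = new HashMap<>();
        if (registerService) {
            Object instance = new GreeterImpl();
            iocMap.put(GreeterImpl.class.getName().toLowerCase(), instance);
            // Also key the instance by each interface name, as doInstance() does
            for (Class<?> itf : GreeterImpl.class.getInterfaces()) {
                iocMap.put(itf.getName(), instance);
            }
        }
        Consumer consumer = new Consumer();
        try {
            Field field = Consumer.class.getDeclaredField("service");
            field.setAccessible(true);
            // Same lookup key as doAutowired(): the declared field type's name
            field.set(consumer, iocMap.get(field.getType().getName()));
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
        return consumer;
    }

    public static void main(String[] args) {
        System.out.println(wire(true).getService().get("tom")); // my name is tom
        System.out.println(wire(false).getService());           // null
    }
}
```

In the unregistered case the lookup silently returns null, so the failure only shows up later as a NullPointerException inside the controller.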
7.2 controller:
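```java
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;

import lombok.extern.slf4j.Slf4j;
// plus imports for the custom Controller, RequestMapping, Autowired and RequestParam annotations

@Controller
@RequestMapping("/demo")
@Slf4j
public class DemoController {

    // Injected by DispatcherServlet, looked up by the interface name
    @Autowired
    IDemoService service;

    // Handles /demo/get?name=...
    @RequestMapping("/get")
    public void get(HttpServletRequest req, HttpServletResponse resp,
                    @RequestParam("name") String name) {
        String res = service.get(name);
        try {
            resp.setContentType("text/html;charset=UTF-8");
            resp.getWriter().println(res);
        } catch (IOException e) {
            log.error("Write response fail", e);
        }
    }
}
```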
Together with the screenshots at the beginning, this confirms that the homemade SpringMVC works.
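For the request /demo/get?name=..., doDispatch() turns the query-parameter map into an Object[] ordered by parameter position and calls Method.invoke. A minimal sketch of that step, with a hypothetical controller and a hard-coded index map standing in for what Handler records:

```java
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

// Sketch of doDispatch(): query-parameter map -> Object[] args -> Method.invoke
public class InvokeDemo {

    public static class DemoController {
        public String get(String name) { return "my name is " + name; }
    }

    public static String dispatch(Map<String, String[]> parameterMap) {
        try {
            Method method = DemoController.class.getMethod("get", String.class);
            Map<String, Integer> paramIndexMap = new HashMap<>();
            paramIndexMap.put("name", 0); // what Handler would record for @RequestParam("name")

            Object[] args = new Object[method.getParameterCount()];
            for (Map.Entry<String, String[]> e : parameterMap.entrySet()) {
                Integer index = paramIndexMap.get(e.getKey());
                if (index != null) {
                    // Multiple values for one name are joined with commas
                    args[index] = String.join(",", e.getValue());
                }
            }
            // Parameters with no matching query value stay null
            return (String) method.invoke(new DemoController(), args);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        Map<String, String[]> query = new HashMap<>();
        query.put("name", new String[]{"tom"});
        query.put("age", new String[]{"18"});             // ignored: no matching parameter
        System.out.println(dispatch(query));              // my name is tom
        System.out.println(dispatch(new HashMap<>()));    // my name is null
    }
}
```

Leaving unmatched parameters as null is why calling the endpoint without name still reaches the method instead of failing.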
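IV. Final Words
This demo implements only the most basic features of SpringMVC. Its purpose is to deepen my understanding of how SpringMVC maps requests to handler methods; the real SpringMVC is, of course, far more involved.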
GitHub repository of the demo: https://github.com/Happy-Ape/Spring