自己动手实现mvc框架

  用过springmvc的可能都知道,要集成springmvc需要在web.xml中加入一个跟随web容器启动的DispatcherServlet,然后由该servlet初始化一些东西,然后所有的web请求都被这个servlet接管。所以自己写mvc的关键就是弄懂这个servlet干了啥。先分析一下springmvc的功能,首先我们写一个接口,就是写一个Controller,然后里面写一个方法,在类或者方法里面使用@RequestMapping直接修饰,当该注解对应的path被请求时,会按照指定格式传入参数并调用该方法,然后按照指定格式将调用的结果写出到向浏览器的输出流中(@ResponseBody),或者转发到jsp,再去由jsp转换生成的servlet去将结果写出到输出流(默认的请求转发),或者重定向到指定的jsp(return "redirect:/test.jsp"等。上面仅仅使用jsp举例子,不代表springmvc只支持使用jsp渲染。但是我们自己写的mvc只是为了演示整个流程和若干细节,并不能全面重写springmvc我也没那个能力重写,所以页面层只用jsp。
  详细分析下写一个mvc的流程:第一:我们也可以使用一个servlet将前端所有请求都接管到一个servlet中去(这里其实filter,servlet都可以实现,比如strus2采用filter接管,springmvc采用servlet,原理大同小异),第二:这个servlet是随着容器自启动,所以需要配置load-on-startup,然后我们在这个servlet的init方法里面可以扫描指定的包(扫描哪些包,可以通过servlet在web.xml中的init-param配置),加载一些注解并将注解配置的属性和类,方法对象的关系保存在一些map中,第三:当页面访问任意后端接口时,最终会经过doGet或者doPost(上层是service方法,为了方便不用service方法),我们可以在这两个方法中根据请求的url路径去找到对应的Controller类和Method对象,然后从请求中拿出参数传入方法需要的参数,得到方法执行的结果,最后根据方法里面指定的返回格式(@ResponseBody这种),或者请求转发,或者重定向做处理。废话不多说,直接上代码。

<?xml version="1.0" encoding="UTF-8"?>
<web-app xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://java.sun.com/xml/ns/javaee" xmlns:web="http://java.sun.com/xml/ns/javaee/web-app_2_5.xsd" xsi:schemaLocation="http://java.sun.com/xml/ns/javaee http://java.sun.com/xml/ns/javaee/web-app_3_0.xsd" id="WebApp_ID" version="3.0">
  <display-name>mymvc</display-name>
  <servlet>
    <servlet-name>dispacher</servlet-name>
    <servlet-class>com.rd.servlet.DispatcherServlet</servlet-class>
    <init-param>
        <param-name>package</param-name>
        <param-value>com.rd.controller</param-value>
    </init-param>
    <load-on-startup>1</load-on-startup>
  </servlet>
  <servlet-mapping>
    <servlet-name>dispacher</servlet-name>
    <url-pattern>/</url-pattern>
  </servlet-mapping>
</web-app>
package com.rd.servlet;

import java.io.IOException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang3.StringUtils;

import com.rd.annotation.Path;
import com.rd.annotation.RespJson;
import com.rd.util.JsonUtil;
import com.rd.util.MethodUtil;
import com.rd.util.ScanClassUtil;

/**
 * 类似springmvc的将请求全部纳入控制范围的servlet
 * @author rongdi
 * @date 2017年9月20日 下午1:58:26
 */
public class DispatcherServlet extends HttpServlet {
    
    private static final long serialVersionUID = 1L;
    
    //path和Class的映射
    private final static Map<String, Class<?>> classMap = new HashMap<String, Class<?>>();
    
    //path和Method的映射
    private final static Map<String, Method> methodMap = new HashMap<String, Method>();
    
    //存放被@RespJson修饰的类
    private final static Set<Class<?>> classRespJsons = new HashSet<Class<?>>();
    
    //存放被@RespJson修饰的方法
    private final static Set<Method> methodRespJsons = new HashSet<Method>();
    
    @Override
    public void init(ServletConfig config) throws ServletException {
        System.out.println("---DispatcherServlet初始化开始---");
        //获取web.xml中配置的要扫描的包
        String basePackage = config.getInitParameter("package");
        //配置了多个包
        if (basePackage.indexOf(",")>0) {
            //按逗号进行分隔
            String[] packageNameArr = basePackage.split(",");
            for (String packageName : packageNameArr) {
                add2ClassMap(packageName);
            }
        }else {
            add2ClassMap(basePackage);
        }
        System.out.println("----DispatcherServlet初始化结束---");
    }
    
    /**
     * 将被注解修饰的类
     * @param packageName
     */
    private void add2ClassMap(String packageName){
        Set<Class<?>> setClasses =  ScanClassUtil.getClasses(packageName);
        for (Class<?> clazz :setClasses) {
            String pathAttrValue = null;
            //判断类被注解修饰
            if (clazz.isAnnotationPresent(Path.class)) {
                //获取path的Annotation的实例
                Path pathInstance = clazz.getAnnotation(Path.class);
                //获取Annotation的实例的value属性的值
                pathAttrValue = pathInstance.value();
                if (StringUtils.isNotEmpty(pathAttrValue)) {
                    pathAttrValue = handPathStr(pathAttrValue);
                    classMap.put(pathAttrValue, clazz);
                }
            }
            if(clazz.isAnnotationPresent(RespJson.class)) {
                classRespJsons.add(clazz);
            }
            //判断方法被注解修饰
            Method[] methods = clazz.getMethods();
            for(Method m:methods) {
                //判断方法被注解修饰
                if(m.isAnnotationPresent(Path.class)) {
                    //获取path的Annotation的实例
                    Path pathInstance = m.getAnnotation(Path.class);
                    //获取Annotation的实例的value属性的值
                    String methodPathValue = pathInstance.value();
                    if (StringUtils.isNotEmpty(methodPathValue)) {
                        methodPathValue = handPathStr(methodPathValue);
                        pathAttrValue = handPathStr(pathAttrValue);
                        methodMap.put(pathAttrValue+methodPathValue, m);
                    }
                }
                if(m.isAnnotationPresent(RespJson.class)) {
                    methodRespJsons.add(m);
                }
            }
        }
    }

    /**
     * 处理一下路径,前面后面的斜杠
     * @param pathStr
     * @return
     */
    private String handPathStr(String pathStr) {
        if(pathStr.endsWith("/")) {
            pathStr = pathStr.substring(0,pathStr.length()-1);
        }
        if(!pathStr.startsWith("/")) {
            pathStr = "/"+pathStr;
        }
        return pathStr;
    }
    
    @Override
    public void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        
        Map<String,String[]> params = request.getParameterMap();
        String url = request.getRequestURI();
        //当请求根路径是,请求转发到首页
        if("/".equals(url)) {
            request.getRequestDispatcher("/index.jsp").forward(request, response);
            return;
        }
        if("/favicon.ico".equals(url)) {
            return;
        }
        Method m = null;
        if(methodMap.containsKey(url)) {
            m = methodMap.get(url);
        } else {
            //这个过程,其实可以优化,如果存在通配符匹配,不用每次都循环匹配,可以缓存起来,第二次直接用,这里暂时忽略优化问题
            Set<String> urls = methodMap.keySet();
            for(String murl:urls) {
                String reg = "^"+murl.replace("*", ".*?")+"$";
                if(Pattern.matches(reg, url)) {
                    m = methodMap.get(murl);
                    break;
                }
            }
        }
        if(m == null) {
            throw new ServletException("没有找到与路径:"+url+"对应的处理方法");
        }
        
        try {
            /**
             * 这里需要获取参数名,jdk1.8之后可以直接直接反射获取,条件比较恶心,需要开启开关
             * 如下直接使用javassist字节码工具类实现,也可以用asm等其他工具
             */
            String[] paramNames = MethodUtil.getAllParamaterName(m);
            List<String> paramValues = new ArrayList<String>();
            for(String paramName:paramNames) {
                if(params.get(paramName) == null) {
                    paramValues.add(null);
                } else {
                    paramValues.add(params.get(paramName)[0]);
                }
                
            }
            /**
             * 调用方法所在类的默认构造方法,生成执行方法的对象(springmvc里的这个对象是单例的,这里为了省事,每次都new一个出来),
             * 然后执行方法,返回结果
             */
            Class<?> cla = m.getDeclaringClass();
            Object result = m.invoke(cla.newInstance(), paramValues.toArray());
            /**
             * 如果方法返回类型为 void,则该调用结果返回 null,如果返回值为void,则直接跳转到同路径的jsp页面上,
             * 为了简单起见这里后缀写死.jsp,实际上springmvc是支持配置ViewResolver,
             * 可以指定请求转发或者重定向所在的界面层的前缀和后缀。
             */
            if(result == null) {
                request.getRequestDispatcher(url+".jsp").forward(request, response);
            }
            /**
             * 这里springmvc默认是请求转发到jsp
             * 为了方便这里直接根据修饰返回类型的注解,确定用哪种方式序列化,
             */
            if(classRespJsons.contains(cla) || methodRespJsons.contains(m)) {
                result = JsonUtil.serialize(result);
                //输出json到浏览器
                response.getWriter().print(result);
            } else if(result.toString().startsWith("redirect:")) { //重定向
                //去掉前缀就是重定向到的路径,实际上这里不严谨,应该加上一个项目的上下文路径
                response.sendRedirect(result.toString().substring(9)+".jsp");
            } else { //请求转发
                request.getRequestDispatcher(result.toString()+".jsp").forward(request, response);
            }
            
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    @Override
    public void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        doGet(request, response);
    }

     public static void main(String[] args) {
            String url = "/test/a";
            String reg = "^/test/.*?$";
            System.out.println(Pattern.matches(reg, url));
            System.out.println(Long.class);
                    
        }
}
package com.rd.controller;

import java.util.HashMap;
import java.util.Map;

import com.rd.annotation.Path;
import com.rd.annotation.RespJson;

@Path(value="/test")
public class TestController {

    @Path(value="/text")
    public void text() {
    }
    
    @Path(value="/redirect")
    public String redirect(String a) {
        
        return "redirect:/test/redirect";
    }
    
    @Path(value="/json")
    @RespJson
    public Map json(String a) {
        return new HashMap(){{put("a",a);}};
    }
    
}
package com.rd.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface Path {
    
    //访问的匹配路径
    String value();
   
}
package com.rd.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface RespJson {
    
}
package com.rd.util;

import java.util.HashMap;

import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;

public class JsonUtil {
    
    private static ObjectMapper mapper;

    static {
        mapper = new ObjectMapper();
        mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES,
                false);
    }
    
    public static String serialize(Object obj) throws Exception {

        if (obj == null) {
            throw new IllegalArgumentException("obj should not be null");
        }
        return mapper.writeValueAsString(obj);
    }
    
    public static void main(String[] args) throws Exception {
        System.out.println(serialize(new HashMap(){{put("name","zhangsan");}}));
    }
}
package com.rd.util;
import java.lang.reflect.Method;

import javassist.ClassClassPath;
import javassist.ClassPool;  
import javassist.CtClass;  
import javassist.CtMethod;  
import javassist.Modifier;  
import javassist.NotFoundException;  
import javassist.bytecode.CodeAttribute;  
import javassist.bytecode.LocalVariableAttribute;  
import javassist.bytecode.MethodInfo;  
  
/**
 * 使用javassist的方法工具
 * @author rongdi
 * @date 2017年9月20日 上午11:21:29
 */
public class MethodUtil {  
  
    public static String[] getAllParamaterName(Method method)  
        throws NotFoundException {  
        Class<?> clazz = method.getDeclaringClass();  
        ClassPool pool = ClassPool.getDefault();  
        ClassClassPath classPath = new ClassClassPath(MethodUtil.class);  
        pool.insertClassPath(classPath);  
        CtClass clz = pool.get(clazz.getName());  
        CtClass[] params = new CtClass[method.getParameterTypes().length];  
        for (int i = 0; i < method.getParameterTypes().length; i++) {  
            params[i] = pool.getCtClass(method.getParameterTypes()[i].getName());  
        }  
        CtMethod cm = clz.getDeclaredMethod(method.getName(), params);  
        MethodInfo methodInfo = cm.getMethodInfo();  
        CodeAttribute codeAttribute = methodInfo.getCodeAttribute();  
        LocalVariableAttribute attr = (LocalVariableAttribute) codeAttribute  
            .getAttribute(LocalVariableAttribute.tag);  
        int pos = Modifier.isStatic(cm.getModifiers()) ? 0 : 1;  
        String[] paramNames = new String[cm.getParameterTypes().length];  
        for (int i = 0; i < paramNames.length; i++) {  
            paramNames[i] = attr.variableName(i + pos);  
        }  
        return paramNames;  
    }  
  
}  
package com.rd.util;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.Properties;

/**
 * @author rongdi
 * @date 2017年9月16日 下午3:36:08
 */
public class PropertyUtil {

    private static Properties props;

    synchronized static private void loadProps() {
        props = new Properties();
        InputStream in = null;
        
        try {
            try {
                String path = getJarDir()+"/application.properties";
                in = new FileInputStream(path);
            } catch(Exception e){
                in = PropertyUtil.class.getClassLoader().getResourceAsStream("application.properties");
            }
            props.load(in);
        } catch (FileNotFoundException e) {
            //logger.error("application.properties文件未找到");
        } catch (IOException e) {
            //logger.error("出现IOException");
        } finally {
            try {
                if (null != in) {
                    in.close();
                }
            } catch (IOException e) {
                //logger.error("application.properties文件流关闭出现异常");
            }
        }
    }

    public static String getProperty(String key) {
        if (null == props) {
            loadProps();
        }
        return props.getProperty(key);
    }

    public static String getProperty(String key, String defaultValue) {
        if (null == props) {
            loadProps();
        }
        return props.getProperty(key, defaultValue);
    }

    /**
     * 获取jar绝对路径
     * 
     * @return
     */
    public static String getJarPath() {
        File file = getFile();
        if (file == null)
            return null;
        return file.getAbsolutePath();
    }

    /**
     * 获取jar目录
     * 
     * @return
     */
    public static String getJarDir() {
        File file = getFile();
        if (file == null)
            return null;
        return getFile().getParent();
    }

    /**
     * 获取jar包名
     * 
     * @return
     */
    public static String getJarName() {
        File file = getFile();
        if (file == null)
            return null;
        return getFile().getName();
    }

    /**
     * 获取当前Jar文件
     * 
     * @return
     */
    private static File getFile() {
        // 关键是这行...
        String path = PropertyUtil.class.getProtectionDomain().getCodeSource()
                .getLocation().getFile();
        try {
            path = java.net.URLDecoder.decode(path, "UTF-8"); // 转换处理中文及空格
        } catch (java.io.UnsupportedEncodingException e) {
            return null;
        }
        return new File(path);
    }
}
package com.rd.util;
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.net.JarURLConnection;
import java.net.URL;
import java.net.URLDecoder;
import java.util.Enumeration;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

/**
 * 扫描字节码工具类
 * @author rongdi
 * @date 2017年9月19日 下午4:13:19
 */
public class ScanClassUtil {

    /**
     * 从包package中获取所有的Class
     * @param pack
     * @return
     */
    public static Set<Class<?>> getClasses(String pack) {

        // 第一个class类的集合
        Set<Class<?>> classes = new LinkedHashSet<Class<?>>();
        // 是否循环迭代
        boolean recursive = true;
        // 获取包的名字 并进行替换
        String packageName = pack;
        String packageDirName = packageName.replace('.', '/');
        // 定义一个枚举的集合 并进行循环来处理这个目录下的things
        Enumeration<URL> dirs;
        try {
            dirs = Thread.currentThread().getContextClassLoader().getResources(
                    packageDirName);
            // 循环迭代下去
            while (dirs.hasMoreElements()) {
                // 获取下一个元素
                URL url = dirs.nextElement();
                // 得到协议的名称
                String protocol = url.getProtocol();
                // 如果是以文件的形式保存在服务器上
                if ("file".equals(protocol)) {
                    System.err.println("file类型的扫描");
                    // 获取包的物理路径
                    String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                    // 以文件的方式扫描整个包下的文件 并添加到集合中
                    findAndAddClassesInPackageByFile(packageName, filePath,recursive, classes);
                } else if ("jar".equals(protocol)) {
                    // 如果是jar包文件
                    // 定义一个JarFile
                    System.err.println("jar类型的扫描");
                    JarFile jar;
                    try {
                        // 获取jar
                        jar = ((JarURLConnection) url.openConnection())
                                .getJarFile();
                        // 从此jar包 得到一个枚举类
                        Enumeration<JarEntry> entries = jar.entries();
                        // 同样的进行循环迭代
                        while (entries.hasMoreElements()) {
                            // 获取jar里的一个实体 可以是目录 和一些jar包里的其他文件 如META-INF等文件
                            JarEntry entry = entries.nextElement();
                            String name = entry.getName();
                            // 如果是以/开头的
                            if (name.charAt(0) == '/') {
                                // 获取后面的字符串
                                name = name.substring(1);
                            }
                            // 如果前半部分和定义的包名相同
                            if (name.startsWith(packageDirName)) {
                                int idx = name.lastIndexOf('/');
                                // 如果以"/"结尾 是一个包
                                if (idx != -1) {
                                    // 获取包名 把"/"替换成"."
                                    packageName = name.substring(0, idx).replace('/', '.');
                                }
                                // 如果可以迭代下去 并且是一个包
                                if ((idx != -1) || recursive) {
                                    // 如果是一个.class文件 而且不是目录
                                    if (name.endsWith(".class") && !entry.isDirectory()) {
                                        // 去掉后面的".class" 获取真正的类名
                                        String className = name.substring(
                                                packageName.length() + 1, name
                                                        .length() - 6);
                                        try {
                                            // 添加到classes
                                            classes.add(Class.forName(packageName + '.'
                                                            + className));
                                        } catch (ClassNotFoundException e) {
                                            e.printStackTrace();
                                        }
                                    }
                                }
                            }
                        }
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }

        return classes;
    }
    
    /**
     * 以文件的形式来获取包下的所有Class
     * 
     * @param packageName
     * @param packagePath
     * @param recursive
     * @param classes
     */
    public static void findAndAddClassesInPackageByFile(String packageName,
            String packagePath, final boolean recursive, Set<Class<?>> classes) {
        // 获取此包的目录 建立一个File
        File dir = new File(packagePath);
        // 如果不存在或者 也不是目录就直接返回
        if (!dir.exists() || !dir.isDirectory()) {
            // log.warn("用户定义包名 " + packageName + " 下没有任何文件");
            return;
        }
        // 如果存在 就获取包下的所有文件 包括目录
        File[] dirfiles = dir.listFiles(new FileFilter() {
            // 自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件)
            public boolean accept(File file) {
                return (recursive && file.isDirectory())
                        || (file.getName().endsWith(".class"));
            }
        });
        // 循环所有文件
        for (File file : dirfiles) {
            // 如果是目录 则继续扫描
            if (file.isDirectory()) {
                findAndAddClassesInPackageByFile(packageName + "."
                        + file.getName(), file.getAbsolutePath(), recursive,
                        classes);
            } else {
                // 如果是java类文件 去掉后面的.class 只留下类名
                String className = file.getName().substring(0,
                        file.getName().length() - 6);
                try {
                    // 添加到集合中去
                    //classes.add(Class.forName(packageName + '.' + className));
                    classes.add(Thread.currentThread().getContextClassLoader().loadClass(packageName + '.' + className));  
                    } catch (ClassNotFoundException e) {
                    // log.error("添加用户自定义视图类错误 找不到此类的.class文件");
                    e.printStackTrace();
                }
            }
        }
    }
}
<properties>
        <project.deploy>deploy</project.deploy>
        <jackson.version>2.5.4</jackson.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>javax.servlet</groupId>
            <artifactId>javax.servlet-api</artifactId>
            <version>3.1.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-lang3</artifactId>
            <version>3.4</version>
        </dependency>
        <dependency>
            <groupId>org.javassist</groupId>
            <artifactId>javassist</artifactId>
            <version>3.21.0-GA</version>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-core</artifactId>
            <version>${jackson.version}</version>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-databind</artifactId>
            <version>${jackson.version}</version>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-annotations</artifactId>
            <version>${jackson.version}</version>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>3.8.1</version>
            <scope>test</scope>
        </dependency>

    </dependencies>

完整代码百度云地址(不要吐槽要放git啥的,我屌丝一个,不用那么高大上的东西):https://pan.baidu.com/s/1mi24Lbq

posted on 2017-12-10 15:34  码小D  阅读(1214)  评论(0编辑  收藏  举报