简单代码实现springMVC中DispatcherServlet基本功能

225 阅读5分钟

1.servlet简介

​ servlet是一种运行在web服务器环境中的java程序,常见的使用形式是定义一个java类,继承HttpServlet类,重写HttpServlet的init()和doGet/doPost方法,之后在web.xml定义这个类的servlet响应路径,将该java类编译后的class文件放入tomcat服务器的/webapps/ROOT/WEB-INF/classes文件夹下,将web.xml文件放入/webapps/ROOT/WEB-INF文件夹下,启动tomcat,访问映射路径后得到结果即访问servlet成功。

servlet的工作流程

​ 想搞清楚servlet的工作流程前,必须要了解servlet在web服务器中的流转逻辑,这里以tomcat为例,tomcat服务器主要实现了三个功能:

  1. 建立通道监听端口的请求数据

  2. 创建了servlet容器,规定了servlet执行规则(init->service->(doGet||doPost))

  3. 构建端口请求对象和servlet响应地址的映射关系

所以一个完整的请求流程是:

  1. 用户在web环境中发起一个请求信息
  2. tomcat端口监听对象捕获到请求后会创建request和response对象,将请求报文封装到requset中
  3. request对象通过请求url映射关系找到对应servlet对象,若是首次访问该servlet则执行init方法
  4. 之后执行service方法,根据resquest中的请求类型进一步执行doGet或doPost一类请求方法
  5. 我们在doGet或doPost方法中处理完逻辑后,会将数据封装到response对象中,返回给用户完成请求流程
一个servlet实例

这里使用springboot构建web项目环境

定义一个servlet类:

public class SimpleServlet extends HttpServlet {

    @Override
    public void init() throws ServletException {
        System.out.println("servlet执行");
    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        System.out.println("doGet执行");
        doPost(req,resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        System.out.println("doPost执行");
        resp.setContentType("text/html;charset=UTF-8");
        resp.getWriter().append("你好");
    }

    @Override
    protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        super.service(req, resp);
        System.out.println("service执行");

    }
}

创建servlet注册类

@Configuration
public class ServletConfig  {
    @Bean
    public ServletRegistrationBean servletRegistrationBean() {
        ServletRegistrationBean servletRegistrationBean =
                new ServletRegistrationBean(new SimpleServlet(),"/v4/*");
        return servletRegistrationBean;
    }
}

启动springboot项目,访问http://localhost/v4,页面显示'你好',同时项目控制端打印内容:

servlet执行
doGet执行
doPost执行
service执行

2.spring中对servlet的应用

​ 在上述servlet实例中,实现了对于**/v4** url的请求拦截,如果要拦截其他url的请求我们要创建另外一个servlet了,那么能不能只创建一个servlet完成对所有请求地址的拦截?答案是肯定的,将servlet的拦截地址改为**/***就可以了;在此基础上能不能做一些扩展:根据不同的请求url执行不同的逻辑,当然也可以,spring的MVC模块就是这么做的,中间涉及到了DispatcherServlet、@Controller、@RequestMapping等我们常用到的内容,下面来简单实现springMVC的基本功能。

一个自定义的mvc映射servlet

我们先简化代码实现springMVC的功能:

项目文件结构如下图:

1.定义MVC中常用的注解

@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface PWBAutowired {
    String value() default "";
}

@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface PWBController {
    String value() default "";
}

@Target({ElementType.TYPE,ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface PWBRequestMapping {
    String value() default "";
}

@Target({ElementType.PARAMETER})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface PWBRequestParam {
    String value() default "";
}

@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface PWBService {
    String value() default  "";
}

2.定义一个通用servlet

@WebServlet(name = "PwbDisparcherServlet", urlPatterns = {
        "/*"
})
public class PwbDisparcherServlet extends HttpServlet {
    //存储扫描到的bean类路径名称
    private  List<String> classNameList = new ArrayList<>();
    //存储bean
    private Map<String,Object> iocMap = new HashMap<>();
    //存储url映射关系
    private Map<String, Method> handlerMap = new HashMap<>();

    @Override
    public void init(ServletConfig cfg) throws ServletException {
        //根据配置读取需要实例化的bean
        doScanner("com.demo");
        //将有注解标注的bean实例化并存储
        doInstance();
        //将被调用的对象进行注入
        doAutowire();
        //将所有映射访问路径进行存储
        doHandlerMapping();
        System.out.println("---------------初始化完成!!!!");
    }


    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) {
        doDisparcher(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) {
        doDisparcher(req, resp);
    }
}

这个serlvet的初始化方法包含了以下逻辑:

  1. 扫描包路径获取bean对象信息,存储到classNameList中(doScanner方法)

    private void doScanner(String scanPackage){
        URL url = this.getClass().getClassLoader().getResource(scanPackage.replaceAll("\\.","/"));
        File classFile = new File(url.getFile());
        if(classFile == null){
            throw new RuntimeException(classFile.toString()+"文件路径错误!");
        }
        if(classFile.listFiles() == null){
            throw new RuntimeException("获取不到"+classFile.toString()+"文件路径下的列表数据!");
        }
        for(File file: classFile.listFiles()){
            if(file.isDirectory()){
                //递归
                doScanner(scanPackage+"."+file.getName());
            }else{
                //获取class绝对路径
                if(file.getName().endsWith(".class")&&!file.getName().contains("PWB")&&!file.getName().contains("Application")&&!file.getName().contains("Servlet")){
                    String className = scanPackage+"."+ file.getName().replaceAll(".class","");
                    classNameList.add(className);
                }
            }
    
        }
    }
    
  2. 根据classNameList列表,对bean进行实例化并存储到iocMap中(doInstance方法)

    private void doInstance(){
        try {
            for(String className: classNameList){
                //过滤
                if (!className.contains(".")) {
                    continue;
                }
                //类实例化
                iocMap.put(className, Class.forName(className).newInstance());
            }
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        } catch (InstantiationException e) {
            e.printStackTrace();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        }
    }
    
  3. 根据bean中的@PWBAutowired注解信息,完成对bean中的一类注入(doAutowire方法)

    private void doAutowire() {
        System.out.println("---------------开始注入!!!!");
        try {
            //处理类的依赖注入
            for(Object object: iocMap.values()){
                if(object == null){
                    continue;
                }
                Class clazz = object.getClass();
                //循环类的属性,找要注入的类
                for(Field field : clazz.getDeclaredFields()){
                    if(field.isAnnotationPresent(PWBAutowired.class)){
                        String beanName = field.getAnnotation(PWBAutowired.class).value();
                        if("".equals(beanName)){
                            beanName = field.getType().getName();
                        }
                        //解除限制
                        field.setAccessible(true);
                        field.set(iocMap.get(clazz.getName()), iocMap.get(beanName));
    
                    }
                }
            }
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        }
    }
    
  4. 根据bean中的@PWBController和@PWBRequestMapping注解信息,构建url和method的对应关系,存储到handlerMap中(doHandlerMapping方法)

    private void doHandlerMapping() {
        for(Map.Entry bean: iocMap.entrySet()){
            Class<?> clazz = bean.getValue().getClass();
            //是否有PWBController注解
            if (clazz.isAnnotationPresent(PWBController.class)) {
                //是否有PWBRequestMapping注解
                String mappingUrl = "";
                if (clazz.isAnnotationPresent(PWBRequestMapping.class)) {
                    //获取注解内容
                    mappingUrl = clazz.getAnnotation(PWBRequestMapping.class).value();
                }
                //循环类的方法,记录请求url和对应方法
                Method[] methods = clazz.getMethods();
                for (Method method : methods) {
                    if (method.isAnnotationPresent(PWBRequestMapping.class)) {
                        String url = method.getAnnotation(PWBRequestMapping.class).value();
                        url = (mappingUrl + url).replaceAll("/+", "/");
                        handlerMap.put(url, method);
                    }
                }
            }
        }
    }
    
  5. 根据request中的请求url,匹配method,通过反射执行目标方法(doDisparcher方法)

    private void doDisparcher(HttpServletRequest req, HttpServletResponse resp) {
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        url = url.replace(contextPath,"").replaceAll("/+","/");
        try {
            if(!this.handlerMap.containsKey(url)){
                resp.getWriter().append("-------------404 NOT FOUND!");
                return;
            }
            //调用对应方法
            Method method = (Method) handlerMap.get(url);
            //获取实际请求参数
            Map<String, String[]> params = req.getParameterMap();
            //请求方法的入参类型数组
            Class<?> [] paramsTypes = method.getParameterTypes();
            //处理后的参数容器
            Object[] resultArray = new Object[paramsTypes.length];
            //循环参数类型数组
            for (int i=0; i < paramsTypes.length; i++){
                Class<?> paramType = paramsTypes[i];
                //参数为request直接返回
                if(paramType == HttpServletRequest.class){
                    resultArray[i] = req;
                    continue;
                //参数为response直接返回
                }else if (paramType == HttpServletResponse.class){
                    resultArray[i] = resp;
                    continue;
                //参数为string类型处理
                } else if (paramType == String.class) {
                    //获取参数对象
                    Parameter parameter = method.getParameters()[i];
                    //参数是否有PWBRequestParam注解
                    if(parameter.isAnnotationPresent(PWBRequestParam.class)){
                        String annoParamName = parameter.getAnnotation(PWBRequestParam.class).value();
                        resultArray[i] = Arrays.toString(params.get(annoParamName));
                    }else{
                        String paramName = method.getParameters()[i].getName();
                        resultArray[i] = Arrays.toString(params.get(paramName));
                    }
                }
            }
            //获取方法对应的类对象
            Object obj = (iocMap.get(method.getDeclaringClass().getName()));
            //反射
            method.invoke(obj, resultArray);
        }  catch (IOException e) {
            e.printStackTrace();
        }catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (InvocationTargetException e) {
            e.printStackTrace();
        }
    }
    

创建一个对应的请求Controller

@PWBController
@PWBRequestMapping("/v1")
public class IndexController {
    @PWBAutowired
    private IndexService indexService;

    @PWBRequestMapping("/query")
    public void queryString(HttpServletRequest request, HttpServletResponse response, String name){
        String msg = indexService.queryString(name);
        try {
            response.setContentType("text/html;charset=UTF-8");
            response.getWriter().write(msg);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

@PWBService
public class IndexService {
    public String queryString(String msg){
        return "service返回信息:"+msg+"=====";
    }
}

启动项目,访问http://localhost/v1/query