一文搞懂责任链模式

474 阅读9分钟

背景

以采购单分级审批举例

image.png

我们实现了类似下面的代码

image.png

但存在下面的问题:

  • PurchaseRequestHandler类较为庞大,各个级别的审批方法都集中在一个类中,违反了单一职责原则
  • 修改或者增加审批流程,只能通过修改源代码来实现,客户端无法定制审批流程

概述

将请求发送者和接收者解耦,让多个对象都有机会接收请求,将这些对象连接成一条链,并且沿着这条链传递请求,直到有对象处理它为止。

职责链模式结构图

适用场景

  • 多个对象可以处理同一个请求,具体哪些对象处理,等到请求运行时刻才知道
  • 动态组织一组对象处理请求。客户端可以动态创建职责链来处理请求,还可以改变链中处理者之间的先后次序

缺点

  • 职责链太长,性能差
  • 如果建链不当,可能会造成循环调用,将导致系统陷入死循环

实际应用

Spring MVC

先回顾下 Spring MVC 执行流程:

image.png

  • 用户通过浏览器发送一个HTTP请求到服务器,web服务器接收此请求,若匹配则转交给DispatcherServlet
  • DispatcherServlet 拦截此请求后会调用 HandlerMappingHandlerMapping 根据请求URL找到具体的 Handler(就是我们自己定义的Controller) HandlerExecutionChain(拦截器链,包含了很多定义的 Handlerlnterceptor),然后返回给 DispatcherServlet
  • DispatcherServlet 调用 HandlerAdapterHandlerAdapterHandler 进行封装,然后调用统一的 handle 方法执行 handle,handle 执行完成后返回 ModelAndView
  • DispatcherServlet 解析得到 ViewName,然后调用 ViewReslover,将逻辑视图名解析为真正的视图对象 View。
  • DispatcherServlet 将 model 数据填充到 view ,得到最终的 Responose 返回给用户

其中,责任链模式主要体现在第二步,下面我们具体分析下。

Handlerlnterceptor

拦截器链充当处理者角色,代码很简单,就是允许插入自定义的预处理和后处理逻辑。

public interface HandlerInterceptor {

    /**
     * 在实际处理器执行之前调用,用于预处理。
     * 如果返回true,则继续执行后续的拦截器和处理器;如果返回false,则中断后续处理。
     */
    default boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        return true;
    }

    /**
     * 在实际处理器执行之后,但在视图被渲染之前调用,用于后处理。
     * 可以在此方法中对ModelAndView进行操作。
     */
    default void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
    }

    /**
     * 在请求完成之后调用,无论是否发生异常。
     * 用于清理资源等。
     */
    default void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
    }

}

HandlerExecutionChain

负责具体责任链的调用

/**
 * 处理程序执行链,由处理程序对象和任何处理程序拦截器组成。
 * 由HandlerMapping的HandlerMapping. gethandler返回
 */
public class HandlerExecutionChain {

    /**
     * 处理器对象,通常是Controller
     */
	private final Object handler;

    /**
     * 拦截器数组, 用于在创建`HandlerExecutionChain`时静态地指定拦截器。
     * 创建时明确指定哪些拦截器应该被包含在责任链中。
     */
	@Nullable
	private HandlerInterceptor[] interceptors;

    /**
     * 拦截器列表,用于在运行时动态地添加拦截器。
     * 允许你在`HandlerExecutionChain`创建后,通过调用`addInterceptor`方法动态地添加拦截器
     */
	@Nullable
	private List<HandlerInterceptor> interceptorList;

    /**
     * 记录当前拦截器的索引,用于afterCompletion方法
     */
	private int interceptorIndex = -1;

	public Object getHandler() {
		return this.handler;
	}

	public void addInterceptor(HandlerInterceptor interceptor) {
		initInterceptorList().add(interceptor);
	}

	public void addInterceptors(HandlerInterceptor... interceptors) {
		if (!ObjectUtils.isEmpty(interceptors)) {
			CollectionUtils.mergeArrayIntoCollection(interceptors, initInterceptorList());
		}
	}

	private List<HandlerInterceptor> initInterceptorList() {
		if (this.interceptorList == null) {
			this.interceptorList = new ArrayList<>();
			if (this.interceptors != null) {
				// An interceptor array specified through the constructor
				CollectionUtils.mergeArrayIntoCollection(this.interceptors, this.interceptorList);
			}
		}
		this.interceptors = null;
		return this.interceptorList;
	}

	@Nullable
	public HandlerInterceptor[] getInterceptors() {
		if (this.interceptors == null && this.interceptorList != null) {
			this.interceptors = this.interceptorList.toArray(new HandlerInterceptor[0]);
		}
		return this.interceptors;
	}


	/**
     * 应用所有拦截器的preHandle方法,这个方法按顺序调用每个拦截器的preHandle方法。
     * 如果某个拦截器的preHandle方法返回false,则不会调用后续的拦截器和处理器,并且会触发afterCompletion方法。
     */
	boolean applyPreHandle(HttpServletRequest request, HttpServletResponse response) throws Exception {
		HandlerInterceptor[] interceptors = getInterceptors();
		if (!ObjectUtils.isEmpty(interceptors)) {
			for (int i = 0; i < interceptors.length; i++) {
				HandlerInterceptor interceptor = interceptors[i];
				if (!interceptor.preHandle(request, response, this.handler)) {
					triggerAfterCompletion(request, response, null);
					return false;
				}
				this.interceptorIndex = i;
			}
		}
		return true;
	}

	/**
     * 应用所有拦截器的postHandle方法,在处理器执行之后,这个方法按逆序调用每个拦截器的postHandle方法。
     */
	void applyPostHandle(HttpServletRequest request, HttpServletResponse response, @Nullable ModelAndView mv)
			throws Exception {

		HandlerInterceptor[] interceptors = getInterceptors();
		if (!ObjectUtils.isEmpty(interceptors)) {
			for (int i = interceptors.length - 1; i >= 0; i--) {
				HandlerInterceptor interceptor = interceptors[i];
				interceptor.postHandle(request, response, this.handler, mv);
			}
		}
	}

	/**
     * 触发所有拦截器的afterCompletion方法,无论请求是否成功处理,这个方法都会按逆序调用每个拦截器的afterCompletion方法。
     * 这个方法通常用于资源清理。
     */
	void triggerAfterCompletion(HttpServletRequest request, HttpServletResponse response, @Nullable Exception ex)
			throws Exception {

		HandlerInterceptor[] interceptors = getInterceptors();
		if (!ObjectUtils.isEmpty(interceptors)) {
			for (int i = this.interceptorIndex; i >= 0; i--) {
				HandlerInterceptor interceptor = interceptors[i];
				try {
					interceptor.afterCompletion(request, response, this.handler, ex);
				}
				catch (Throwable ex2) {
					logger.error("HandlerInterceptor.afterCompletion threw exception", ex2);
				}
			}
		}
	}
    // ...
}

doDispatcher()

查看org.springframework.web.servlet.DispatcherServlet#doDispatch代码

protected void doDispatch(HttpServletRequest request, HttpServletResponse response) throws Exception {
    try {
        ModelAndView mv = null;
        try {
            // 1. 根据request找到Handler
            mappedHandler = getHandler(processedRequest);
            if (mappedHandler == null || mappedHandler.getHandler() == null) {
                noHandlerFound(processedRequest, response);
                return;
            }

            // 2. 根据Handler找到HandlerAdapter
            HandlerAdapter ha = getHandlerAdapter(mappedHandler.getHandler());

            // 执行相应Interceptor的preHandle
            if (!mappedHandler.applyPreHandle(processedRequest, response)) {
                return;
            }

            // 3. HandlerAdapter使用Handler处理请求
            mv = ha.handle(processedRequest, response, mappedHandler.getHandler());

            // 当view为空时(比如,Handler返回值为void),根据request设置默认view
            applyDefaultViewName(request, mv);
            // 执行相应Interceptor的postHandle
            mappedHandler.applyPostHandle(processedRequest, response, mv);
        } catch (Exception ex) {
        }
        // 4. 处理返回结果。包括处理异常、渲染页面、发出完成通知触发Interceptor的afterCompletion
        processDispatchResult(processedRequest, response, mappedHandler, mv, dispatchException);
    } catch (Exception ex) {
       triggerAfterCompletion(processedRequest, response, mappedHandler, ex);
    } catch (Error err) {
       triggerAfterCompletion(processedRequest, response, mappedHandler, ex);
    } finally {
    }
}

private void processDispatchResult(HttpServletRequest request, HttpServletResponse response,
                                    @Nullable HandlerExecutionChain mappedHandler, @Nullable ModelAndView mv,
                                    @Nullable Exception exception) throws Exception {

    boolean errorView = false;

    // 包括处理异常
    if (exception != null) {
        if (exception instanceof ModelAndViewDefiningException) {
            logger.debug("ModelAndViewDefiningException encountered", exception);
            mv = ((ModelAndViewDefiningException) exception).getModelAndView();
        } else {
            Object handler = (mappedHandler != null ? mappedHandler.getHandler() : null);
            mv = processHandlerException(request, response, handler, exception);
            errorView = (mv != null);
        }
    }

    // Did the handler return a view to render?
    if (mv != null && !mv.wasCleared()) {
        // 渲染页面
        render(mv, request, response);
    } else {
    }

    if (mappedHandler != null) {
        // Exception (if any) is already handled..
        mappedHandler.triggerAfterCompletion(request, response, null);
    }
}

protected void render(ModelAndView mv, HttpServletRequest request, HttpServletResponse response) throws Exception {

    View view;
    String viewName = mv.getViewName();
    if (viewName != null) {
        // We need to resolve the view name.
        // 解析视图名,返回View对象
        view = resolveViewName(viewName, mv.getModelInternal(), locale, request);
        if (view == null) {
            throw new ServletException("Could not resolve view with name '" + mv.getViewName() +
                    "' in servlet with name '" + getServletName() + "'");
        }
    } else {
        // No need to lookup: the ModelAndView object contains the actual View object.
        view = mv.getView();
    }
    try {
        view.render(mv.getModelInternal(), request, response);
    } catch (Exception ex) {
    }
}

protected View resolveViewName(String viewName, @Nullable Map<String, Object> model,
                                Locale locale, HttpServletRequest request) throws Exception {

    if (this.viewResolvers != null) {
        for (ViewResolver viewResolver : this.viewResolvers) {
            View view = viewResolver.resolveViewName(viewName, locale);
            if (view != null) {
                return view;
            }
        }
    }
    return null;
}

doDispatcher方法处理流程图

image.png

debug过程

public class MyInterceptor implements HandlerInterceptor {
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        System.out.println("MyInterceptor: preHandle");
        return true;
    }

    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
        System.out.println("MyInterceptor: postHandle");
    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        System.out.println("MyInterceptor: afterCompletion");
    }
}
@Configuration
public class WebConfig implements WebMvcConfigurer {
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(new MyInterceptor());
    }
}

@Controller
public class TestController {
    @Autowired
    private RedisTemplate redisTemplate;

    @GetMapping("/test")
    public void test() {
        System.out.println("Hello, World!");
    }
}
前置

1722339528724_F18E9506-9275-46ab-9FEB-D8CBF0EBA344.png

前置执行完成,此时this.interceptorIndex = 2

handle

6a072174607c9881b8fc8fc39d5e41a1.png

2276ded1e027d45f52b9d55f31ac16fa.png

后置

65575b894025c60eaa7600396e017628.png

96a8e25732263eab5d1a21b436640c06.png 后置逻辑执行完毕,此时this.interceptorIndex = 2

完成后

小结

  • 通过数组存储注册在Spring中的HandlerInterceptor,然后通过interceptorIndex作为指针去遍历责任链数组按顺序调用处理者。
  • 使用的是责任链的一种变体(SpringMVC中如果一个拦截器的preHandle方法执行错误,请求处理链将被终止,不会继续往下执行)

Spring AOP

已在 Spring AOP 原理分析中详细分析,此处只做总结

  • Spring 根据 @Before@After@AfterRetuming@AfterThrowing 这些注解, 往集合里面加入对应的 Spring 提供的 Methodlnterceptol 实现,如果你没用 @Before 集合里就没有 MethodBeforeAdviceInterceptor
  • 然后通过一个对象 CglilbMethodlnvocation 将这个集合封装起来,紧接着调用这个对象的 proceed 方法, 具体是利用 currentInterceptorIndex 下标,利用递归顺序地执行集合里面的 MethodInterceptor,下图是调用链的堆栈,可以很直观地看到调用的顺序(从下往上看):

image.png

ps: 上图 chain 集合我们看到第一个被调用的是ExposeInvocationInterceptor ,目的是将创建的 CglilbMethodlnvocation 存入 ThreadLocal 中,方便后面其他 Interceptor 调用的时候能得到这个对象,进行一些调用。见词知意:Expose 暴露

image.png

工作中如何应用

促销计价接口应用了责任链模式,下面简单介绍下。类图如下:

image.png 参考代码

/**
 * 链条初始化配置
 */
@Slf4j
@Component
@Configuration
public class ChainConfig {

    /**
     * 前置节点
     */
    private static final String BEFORE_NODE_NAME = "beforeNode";

    /**
     * 单品节点
     */
    private static final String SINGLE_NODE_NAME = "singleNode";
    

    /**
     * 联报节点
     */
    private static final String UNION_NODE_NAME = "unionNode";

    /**
     * 满减节点
     */
    private static final String FULL_NODE_NAME = "fullNode";

    /**
     * 优惠券节点
     */
    private static final String COUPON_NODE_NAME = "couponNode";

    /**
     * 抵扣码节点
     */
    private static final String DEDUCTION_NODE_NAME = "deductionNode";

    /**
     * 后置节点
     */
    private static final String AFTER_NODE_NAME = "afterNode";

    /**
     * 价格节点集合
     */
    private static List<String> priceNodeList = new ArrayList<>();

    static {
        priceNodeList.add(BEFORE_NODE_NAME);
        priceNodeList.add(SINGLE_NODE_NAME);
        priceNodeList.add(UNION_NODE_NAME);
        priceNodeList.add(FULL_NODE_NAME);
        priceNodeList.add(COUPON_NODE_NAME);
        priceNodeList.add(DEDUCTION_NODE_NAME);
        priceNodeList.add(AFTER_NODE_NAME);
    }

    @Resource
    private ApplicationContext applicationContext;

    @Bean(name = "promotionPriceChain")
    public PromotionPriceChain initChain() {
        log.info("初始化计价节点数据开始");

        PromotionPriceChain promotionPriceChain = new PromotionPriceChain();

        LinkedHashMap<String, PromotionPriceNode> priceNodes = getPriceNodes();

        promotionPriceChain.reloadNode(priceNodes);

        log.info("初始化计价节点数据结束");

        return promotionPriceChain;
    }


    /**
     * 获取价格节点集合
     *
     * @author chenwenning
     * @created: 2021/2/18 10:59 上午
     */
    private LinkedHashMap<String, PromotionPriceNode> getPriceNodes() {

        log.info("开始构建计价节点 {}", JSON.toJSONString(priceNodeList));
        LinkedHashMap<String, PromotionPriceNode> priceNodes = new LinkedHashMap<>();

        for (String nodeName : priceNodeList) {
            PromotionPriceNode promotionPriceNode = applicationContext.getBean(nodeName, PromotionPriceNode.class);
            priceNodes.put(nodeName, promotionPriceNode);
        }
        log.info("计价节点构建成功 {}", JSON.toJSONString(priceNodes));
        return priceNodes;
    }
}
/**
 * 选取链条 执行链条
 */
@Component
@Slf4j
public class ChainProcessor {

    @Resource(name = "promotionPriceChain")
    private PromotionPriceChain promotionPriceChain;


    public Result<PromotionCalculateResultRespVO> execute(PromotionPriceReqVO priceReq){
        PromotionPriceContext context = new PromotionPriceContext();
        context.setContextId(BaseUniqueIdUtil.generate());

        log.info(" 开始执行计价接口 contextId:{} 计价请求参数:{}", context.getContextId(), JSON.toJSONString(priceReq));

        promotionPriceChain.execute(priceReq, context);

        log.info(" 执行计价接口结束 contextId:{} 计价返回参数 :{}", context.getContextId(),JSON.toJSONString(context.getCalculateResult()));
        return Result.success(context.getCalculateResult());
    }
}

/**
 * 促销执行处理链条
 */
@Slf4j
public class PromotionPriceChain {


    /**
     * 链条
     */
    protected LinkedHashMap<String, PromotionPriceNode> nodes = new LinkedHashMap<>();

    /**
     * 执行链
     *
     * @param promotionPriceReq
     * @param context
     */
    public void execute(PromotionPriceReqVO promotionPriceReq, PromotionPriceContext context) {
        StringBuilder executeLog = new StringBuilder("execute : ");

        for (Map.Entry<String, PromotionPriceNode> nodeEntry : nodes.entrySet()) {
            //  当前节点
            String currentNodeName = nodeEntry.getKey();
            //  记录节点执行顺序
            executeLog.append(">").append(currentNodeName);

            //  执行计算节点
            nodeEntry.getValue().calculate(promotionPriceReq, context);

           if (context.getIsInterrupt()) {
               log.info("链条执行顺序: {}", executeLog);
                //  责任链被中断,break循环
                break;
            }
        }
        log.info("链条执行顺序: {}", executeLog);
    }

    /**
     * 装配
     *
     * @param nodes 处理节点
     */
    public void reloadNode(LinkedHashMap<String, PromotionPriceNode> nodes) {
        this.nodes = nodes;
    }
}

/**
 * 促销上下文
 */
@Getter
@Setter
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class PromotionPriceContext {

    /**
     * 计价结果
     */
    private PromotionCalculateResultRespVO calculateResult;

    /**
     * 是否存在秒杀产品
     */
    private Boolean isSecKill = Boolean.FALSE;

    /**
     * 秒杀产品集合
     */
    private List<Integer> secKillProductList;

    /**
     * 秒杀套餐集合
     */
    private List<Integer> secKillPackageList;

    /**
     *  是否中断执行
     */
    private Boolean isInterrupt = Boolean.FALSE;

    /**
     * 计价上下文id
     */
    private Long contextId;
}
public interface PromotionPriceNode {

    /**
     * 执行计算
     */
    void calculate(PromotionPriceReqVO promotionPriceReq, PromotionPriceContext context);

    /**
     * 获取产品的价格 价格传递 单品->联报->优惠券
     */
    default Map<Integer, BigDecimal> productPriceMap(PromotionPriceContext context){
    }

    /**
     * 均摊套餐优惠金额
     */
    default List<ProductPackageDetailCalculateResultRespVO> calculatePackageProductApportionAmount(
            List<ProductPackageDetailCalculateResultRespVO> packageDetailCalculateResultRespList
            , BigDecimal packageTotalPrice, BigDecimal discountAmount) {
    }
}

促销链条实现类截图

image.png

参考链接