springboot+时间轮的简单实现

427 阅读4分钟

1、背景

    主要是为了解决部分项目功能需要延时处理的问题。采用时间轮的方式可以让延时时间的设置更自由，
延时精度也可以更好地控制。

2、框架设计

未命名文件.png

3、模块的实现

3.1 服务注册

服务注册支持两种方式：
    1）注解扫描方式
       通过启动注解扫描指定包路径下任务的处理服务和方法，以方法参数类型作为 key、
    服务对象和方法名作为 value 缓存起来，同时将服务对象注册到 Spring IoC 容器中。
    
/**
 * Registers a {@link CustomerInterfaceRegistryPostProcesser} bean for the class annotated with
 * {@code @EnableTimerWheel}, handing it the packages to scan.
 *
 * <p>The packages come from the annotation's {@code basePackages} attribute. Blank entries are
 * ignored; when no package remains, the package of the annotated class is used instead.
 */
public class CustomerImportBeanDefinitionRegistrar implements ImportBeanDefinitionRegistrar {

    @Override
    public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
        AnnotationAttributes enableAttrs = AnnotationAttributes.fromMap(
                importingClassMetadata.getAnnotationAttributes(EnableTimerWheel.class.getName()));
        // null when the importing class does not carry @EnableTimerWheel: nothing to register.
        if (enableAttrs != null) {
            this.registerBeanDefinitions(importingClassMetadata, enableAttrs, registry,
                    generateBaseBeanName(importingClassMetadata, 0));
        }
    }

    /**
     * Registers the post-processor bean definition under {@code beanName}.
     *
     * @param importingClassMetadata metadata of the class carrying {@code @EnableTimerWheel}
     * @param enableAttrs            attributes of that annotation
     * @param registry               registry receiving the bean definition
     * @param beanName               name for the registered post-processor
     */
    private void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, AnnotationAttributes enableAttrs,
                                         BeanDefinitionRegistry registry, String beanName) {
        BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(CustomerInterfaceRegistryPostProcesser.class);
        // Collect into an explicitly mutable, properly typed list: the default package may be appended below.
        List<String> basePackages = Arrays.stream(enableAttrs.getStringArray("basePackages"))
                .filter(StringUtils::hasText)
                .collect(Collectors.toCollection(ArrayList::new));
        if (CollectionUtils.isEmpty(basePackages)) {
            basePackages.add(getDefaultBasePackage(importingClassMetadata));
        }
        builder.addPropertyValue("basePackages", basePackages);
        registry.registerBeanDefinition(beanName, builder.getBeanDefinition());
    }

    /**
     * Builds a bean name of the form {@code <importingClass>#CustomerImportBeanDefinitionRegistrar#<index>}.
     */
    private static String generateBaseBeanName(AnnotationMetadata importingClassMetadata, int index) {
        return importingClassMetadata.getClassName() + "#" + CustomerImportBeanDefinitionRegistrar.class.getSimpleName() + "#" + index;
    }

    /**
     * Returns the package name of the class carrying {@code @EnableTimerWheel}.
     */
    private static String getDefaultBasePackage(AnnotationMetadata importingClassMetadata) {
        return ClassUtils.getPackageName(importingClassMetadata.getClassName());
    }
}
package com.leelen.cloud.timerwheel;

import com.google.common.base.Strings;
import com.leelen.cloud.annotations.TimerWheel;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanDefinitionHolder;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.factory.support.SimpleBeanDefinitionRegistry;
import org.springframework.core.Ordered;
import org.springframework.core.PriorityOrdered;
import org.springframework.core.type.filter.TypeFilter;

import java.lang.annotation.Annotation;
import java.util.Set;

/**
 * @version: 1.00.00
 * @description: 自定义注解处理
 * @copyright: Copyright (c) 2021 立林科技 All Rights Reserved
 * @company: 厦门立林科技有限公司
 * @author: hj
 * @date: 2021-11-06 15:22
 */
@Slf4j
public class CustomerInterfaceRegistryPostProcesser implements PriorityOrdered, BeanDefinitionRegistryPostProcessor {

   private String[] basePackages;


    @Override
    public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry beanDefinitionRegistry) throws BeansException {

    }

    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory configurableListableBeanFactory) throws BeansException {
        DefaultListableBeanFactory defaultListableBeanFactory = (DefaultListableBeanFactory) configurableListableBeanFactory;
        Set<BeanDefinitionHolder> beanDefinitionHolderSet = this.getBeanDefinitionHolderSet(basePackages);
        for (BeanDefinitionHolder beanDefinitionHolder : beanDefinitionHolderSet) {
            try {
                Class<?> cl = Class.forName(beanDefinitionHolder.getBeanDefinition().getBeanClassName());
                if (cl.getAnnotations() != null && cl.getAnnotations().length > 0) {
                    for (Annotation annotation : cl.getAnnotations()) {
                        if (annotation instanceof TimerWheel) {
                            TimerWheel timerWheel = (TimerWheel) annotation;
                            registerAliaName(configurableListableBeanFactory, beanDefinitionHolder.getBeanName(), cl.getName());
                            if (!Strings.isNullOrEmpty(timerWheel.value())) {
                                registerAliaName(configurableListableBeanFactory, beanDefinitionHolder.getBeanName(), timerWheel.value());
                            }
                            //为了让autowired注解生效
                            RootBeanDefinition rootBeanDefinition = new RootBeanDefinition(cl);
                            defaultListableBeanFactory.registerBeanDefinition(beanDefinitionHolder.getBeanName(), rootBeanDefinition);
                        }
                    }
                }
            } catch (ClassNotFoundException e) {
                log.error("CustomerInterfaceRegistryPostProcesser.postProcessBeanFactory className={}", beanDefinitionHolder.getBeanDefinition().getBeanClassName(), e);
            }
        }
    }

    /**
     * 注册别名
     *
     * @param configurableListableBeanFactory
     * @param beanId
     * @param value
     */
    private void registerAliaName(ConfigurableListableBeanFactory configurableListableBeanFactory, String beanId, String value) {
        if (!configurableListableBeanFactory.containsBeanDefinition(value)) {
            configurableListableBeanFactory.registerAlias(beanId, value);
        }
    }

    /**
     * 获取bean定义
     *
     * @return
     */
    private Set<BeanDefinitionHolder> getBeanDefinitionHolderSet(String[] basePackages) {
        BeanDefinitionRegistry beanDefinitionRegistry = new SimpleBeanDefinitionRegistry();
        CustomerClassPathBeanDefinitionScanner scanner = new CustomerClassPathBeanDefinitionScanner(beanDefinitionRegistry, false);
        TypeFilter typeFilter = (metadataReader, metadataReaderFactory) -> {
            if (metadataReader.getClassMetadata().isConcrete() && metadataReader.getAnnotationMetadata().hasAnnotation(TimerWheel.class.getName())) {
                return true;
            }
            return false;
        };
        scanner.addIncludeFilter(typeFilter);
        return scanner.doScan(basePackages);
    }

    @Override
    public int getOrder() {
        return Ordered.LOWEST_PRECEDENCE;
    }

    public String[] getBasePackages() {
        return basePackages;
    }

    public void setBasePackages(String[] basePackages) {
        this.basePackages = basePackages;
    }
}
    2）手动调用方法注册
       直接手动调用注册方法完成注册。

3.2 时间轮

时间轮.png

时间轮的工作原理可以类比时钟：如上图所示，箭头（指针）沿固定方向、按固定频率转动，每跳动一次称为一个 tick。
时间轮主要有几个重要参数：tickDuration（每次跳动的持续时间）、ticksPerWheel（时间轮一轮的插槽数）、
currentTick（当前指针所在的插槽下标）。
我们约定指针沿某一方向运行：当指针运行到某个插槽时，计算出该插槽当前所代表的时间戳，
该时间戳对应的任务即已到期（timeout）；向时间轮投递任务时，先计算出任务的执行时间和插槽位置，再将任务放入对应插槽。
package com.leelen.cloud.timerwheel;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.leelen.cloud.entity.Response;
import com.leelen.cloud.entity.Slot;
import com.leelen.cloud.entity.Subcriber;
import com.leelen.cloud.utils.SpringContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import java.util.List;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * Single-level timing wheel that delays tasks and hands them to a {@link Dispatch} once they are due.
 *
 * <p>The wheel owns {@code ticksPerWheel + 1} slots. A dedicated "timer-wheel" thread processes
 * absolute tick {@code T} no earlier than {@code startTime + T * tickDuration}. Tick {@code T} maps to
 * slot {@code T % (ticksPerWheel + 1)}. Inside a slot, tasks are keyed by their exact execution time,
 * so tasks more than one revolution away can share a slot with nearer ones.
 *
 * <p>Structural changes to slots ({@link #addTask}, {@link #removeTask} and the per-tick processing)
 * happen under the write lock, so the wheel thread and callers always agree on the current tick.
 *
 * @version: 1.00.00
 * @description: Timing wheel
 * @copyright: Copyright (c) 2021 Leelen Technology All Rights Reserved
 * @company: Xiamen Leelen Technology Co., Ltd.
 * @author: hj
 * @date: 2021-10-27 13:11
 */
@Slf4j
@Component
public class TimingWheel {

    /**
     * Pool that runs dispatch work off the wheel thread. When its queue (10000 entries) is full, the
     * oldest queued batch is discarded ({@link ThreadPoolExecutor.DiscardOldestPolicy}).
     */
    private static final ThreadPoolExecutor SEND_MSG_SERVICE = new ThreadPoolExecutor(10, 20, 180L, TimeUnit.SECONDS,
            new ArrayBlockingQueue<>(10000), new ThreadPoolExecutor.DiscardOldestPolicy());

    /**
     * Duration of one tick: configured in seconds, converted to milliseconds by {@link #init()}.
     */
    @Value("${timeWheel.tickDuration:1}")
    private Long tickDuration;

    /**
     * Spring bean name of the {@link Dispatch} implementation that delivers due tasks.
     */
    @Value("${timeWheel.dispatchServiceName:orderDispatch}")
    private String dispatchServiceName;

    /**
     * Ticks per revolution; the wheel allocates {@code ticksPerWheel + 1} slots.
     */
    @Value("${timeWheel.ticksPerWheel:60}")
    private Integer ticksPerWheel;

    /**
     * Number of slots ({@code ticksPerWheel + 1}); fixed by {@link #init()}.
     */
    private int slotCount;

    /**
     * Absolute index of the next tick to process. Only the wheel thread advances it, under the write lock.
     */
    private volatile long tick;

    /**
     * Wall-clock time (ms) at which tick 0 is due.
     */
    private volatile long startTime;

    /**
     * Loop flag of the wheel thread; volatile because other threads write it.
     */
    private volatile boolean isRunning = true;

    /**
     * The wheel's slots, indexed {@code 0 .. slotCount - 1}.
     */
    private final List<Slot> timerWheelList = Lists.newArrayList();

    /**
     * Thread currently driving the wheel; replaced on restart because a terminated thread cannot be restarted.
     */
    private Thread timerWheelThread;

    /**
     * Registry used to look up the subscribers of an event type.
     */
    private SubcriberRegister subcriberRegister;

    /**
     * Dispatcher that delivers due tasks to subscribers.
     */
    private Dispatch dispatch;

    /**
     * Guards slot contents and {@link #tick} against concurrent modification.
     */
    private final ReadWriteLock lock = new ReentrantReadWriteLock();

    public SubcriberRegister getSubcriberRegister() {
        return subcriberRegister;
    }

    /**
     * Starts (or restarts) the wheel thread if it is not alive.
     *
     * <p>A fresh thread is created each time because a terminated {@link Thread} cannot be started
     * again. Ticks that elapsed while stopped are processed without sleeping once it restarts.
     *
     * @throws IllegalStateException if {@link #init()} has not been called yet
     */
    public synchronized void startThread() {
        Preconditions.checkState(timerWheelThread != null, "TimingWheel.init() must be called before startThread()");
        if (!timerWheelThread.isAlive()) {
            isRunning = true;
            timerWheelThread = newWheelThread();
            timerWheelThread.start();
        }
    }

    /**
     * Stops the wheel thread and waits for it to terminate.
     */
    public synchronized void stopThread() {
        isRunning = false;
        if (timerWheelThread == null || !timerWheelThread.isAlive()) {
            return;
        }
        timerWheelThread.interrupt();
        try {
            timerWheelThread.join();
        } catch (InterruptedException e) {
            // Preserve the caller's interrupt status instead of swallowing it.
            Thread.currentThread().interrupt();
            log.error("时间轮停止失败", e);
        }
    }

    /**
     * Initializes the slots, looks up the dispatcher bean and starts the wheel thread.
     *
     * @throws IllegalStateException if {@code tickDuration} is not positive or {@code ticksPerWheel} is negative
     */
    public void init() {
        // A zero tick would make the wheel thread spin without sleeping and break slot arithmetic.
        Preconditions.checkState(tickDuration != null && tickDuration > 0, "timeWheel.tickDuration must be positive");
        Preconditions.checkState(ticksPerWheel != null && ticksPerWheel >= 0, "timeWheel.ticksPerWheel must not be negative");
        tickDuration = TimeUnit.MILLISECONDS.convert(tickDuration, TimeUnit.SECONDS);
        slotCount = ticksPerWheel + 1;
        for (int i = 0; i < slotCount; i++) {
            timerWheelList.add(new Slot(i));
        }
        subcriberRegister = SubcriberRegister.getInstance();
        dispatch = SpringContextUtils.getBean(dispatchServiceName, Dispatch.class);
        tick = 0;
        startTime = System.currentTimeMillis();
        timerWheelThread = newWheelThread();
        timerWheelThread.start();
    }

    /**
     * Schedules {@code t} to be dispatched after {@code delayTime}.
     *
     * <p>The delay is rounded up to the next tick boundary, so the task is not delivered before the
     * delay has elapsed (as measured by {@link System#currentTimeMillis()}). Negative delays are treated
     * as zero. A task that is due immediately goes into the next tick the wheel has not processed yet.
     *
     * @param t         the task; must not be {@code null}
     * @param delayTime delay before the task becomes due
     * @param timeUnit  unit of {@code delayTime}; must not be {@code null}
     * @param <T>       task type
     * @return the task's coordinates ({@code tick} = slot index, {@code executeTime}) for use with
     *         {@link #removeTask} / {@link #queryTask}; both are {@code -1} if the task could not be stored
     */
    public <T> Response addTask(T t, long delayTime, TimeUnit timeUnit) {
        Preconditions.checkNotNull(t);
        Preconditions.checkNotNull(timeUnit);
        long delayMillis = Math.max(0L, TimeUnit.MILLISECONDS.convert(delayTime, timeUnit));
        int slotIndex = -1;
        long executeTime = -1L;
        lock.writeLock().lock();
        try {
            long targetTick = computeTargetTick(delayMillis);
            int targetSlot = (int) (targetTick % slotCount);
            long targetTime = Math.addExact(startTime, Math.multiplyExact(targetTick, tickDuration));
            log.info("tick={}, executeTime={}, startTime={}", targetSlot, targetTime, startTime);
            timerWheelList.get(targetSlot).add(t, targetTime);
            // Publish the coordinates only once the task has actually been stored.
            slotIndex = targetSlot;
            executeTime = targetTime;
        } catch (Exception e) {
            log.error("添加任务异常", e);
        } finally {
            lock.writeLock().unlock();
        }
        return Response.builder()
                .executeTime(executeTime)
                .tick(slotIndex)
                .build();
    }

    /**
     * Removes a previously added task.
     *
     * @param t           the task
     * @param executeTime execution time returned by {@link #addTask}
     * @param tick        slot index returned by {@link #addTask}
     * @param <T>         task type
     */
    public <T> void removeTask(T t, long executeTime, int tick) {
        Slot slot = timerWheelList.get(tick);
        Preconditions.checkNotNull(slot);
        lock.writeLock().lock();
        try {
            slot.remove(executeTime, t);
        } catch (Exception e) {
            log.error("移除任务异常", e);
        } finally {
            lock.writeLock().unlock();
        }
    }

    /**
     * Returns the tasks stored in slot {@code tick} for execution time {@code executeTime}.
     *
     * @param executeTime execution time returned by {@link #addTask}
     * @param tick        slot index returned by {@link #addTask}
     * @return the matching tasks; an empty set (never {@code null}) when there are none or the lookup fails
     */
    public Set<?> queryTask(Long executeTime, int tick) {
        Slot slot = timerWheelList.get(tick);
        Set<?> result = null;
        lock.readLock().lock();
        try {
            result = slot.get(executeTime);
        } catch (Exception e) {
            log.error("查询任务异常", e);
        } finally {
            lock.readLock().unlock();
        }
        if (result == null) {
            return Sets.newCopyOnWriteArraySet();
        }
        return result;
    }

    /**
     * Computes the absolute tick at which a task delayed by {@code delayMillis} from now becomes due.
     *
     * <p>Rounds up to a tick boundary and never returns a tick that has already been processed. Must be
     * called with the write lock held so that {@link #tick} cannot advance concurrently.
     *
     * @param delayMillis non-negative delay in milliseconds
     * @return absolute tick index
     * @throws ArithmeticException if the due time overflows a {@code long}
     */
    private long computeTargetTick(long delayMillis) {
        long tickMillis = tickDuration;
        long elapsed = Math.max(0L, System.currentTimeMillis() - startTime);
        long dueOffset = Math.addExact(elapsed, delayMillis);
        long ceilTicks = Math.addExact(dueOffset, tickMillis - 1) / tickMillis;
        return Math.max(ceilTicks, tick);
    }

    /**
     * Takes the tasks keyed by {@code executeTime} out of slot {@code slotIndex} and dispatches them
     * asynchronously on {@link #SEND_MSG_SERVICE}.
     */
    private void notifySubscriber(int slotIndex, long executeTime) {
        Slot slot = timerWheelList.get(slotIndex);
        Set<?> dueTasks = slot.get(executeTime);
        slot.remove(executeTime, null);
        // Submitting empty batches every tick would waste pool capacity and, under DiscardOldestPolicy,
        // could even evict real work from the queue.
        if (CollectionUtils.isEmpty(dueTasks)) {
            return;
        }
        SEND_MSG_SERVICE.execute(() -> dueTasks.forEach(event -> {
            CopyOnWriteArraySet<Subcriber> subcribers = subcriberRegister.getSubcribers(event.getClass());
            this.dispatch.dispatchTask(event, subcribers);
        }));
    }

    /**
     * Processes the current tick under the write lock and advances {@link #tick}.
     *
     * <p>A failure while dispatching one tick is logged and does not stop the wheel.
     */
    private void processCurrentTick() {
        lock.writeLock().lock();
        try {
            notifySubscriber((int) (tick % slotCount), startTime + tick * tickDuration);
        } catch (RuntimeException e) {
            log.error("时间轮运行异常", e);
        } finally {
            tick++;
            lock.writeLock().unlock();
        }
    }

    /**
     * Sleeps until absolute tick {@code tick} is due; returns immediately if it is already due.
     *
     * @param startTime time (ms) at which tick 0 is due
     * @param tick      absolute tick to wait for
     * @throws InterruptedException if the wheel thread is interrupted while sleeping
     */
    private void waitForTime(long startTime, long tick) throws InterruptedException {
        long restoreTime = tick * tickDuration - (System.currentTimeMillis() - startTime);
        if (restoreTime <= 0) {
            return;
        }
        Thread.sleep(restoreTime);
    }

    /**
     * Creates the (not yet started) thread that drives the wheel.
     */
    private Thread newWheelThread() {
        return new Thread(new timerWheelRunnabe(), "timer-wheel");
    }

    /**
     * Wheel loop: process the current tick, then sleep until the next one, until stopped or interrupted.
     */
    public class timerWheelRunnabe implements Runnable {

        @Override
        public void run() {
            try {
                while (isRunning && !Thread.currentThread().isInterrupted()) {
                    processCurrentTick();
                    waitForTime(startTime, tick);
                }
            } catch (InterruptedException e) {
                log.error("InterruptedException:时间轮停止");
            }
        }
    }
}

3.3 任务分发

1、按照“任务类型 + 任务执行对象”的维度进行分发；

2、按照任务添加的先后顺序进行分发。
package com.leelen.cloud.timerwheel;

import com.leelen.cloud.entity.Subcriber;
import org.springframework.util.CollectionUtils;

import java.util.concurrent.CopyOnWriteArraySet;

/**
 * Base class for delivering a due timing-wheel task to its subscribers.
 *
 * <p>Subclasses define how a single subscriber is invoked by implementing
 * {@link #executorWay(Object, Subcriber)}.
 *
 * @version: 1.00.00
 * @description: Task dispatch
 * @copyright: Copyright (c) 2021 Leelen Technology All Rights Reserved
 * @company: Xiamen Leelen Technology Co., Ltd.
 * @author: hj
 * @date: 2021-10-29 13:12
 */
public abstract class Dispatch {

    /**
     * Passes {@code event} to each subscriber, in the set's iteration order.
     * Does nothing when {@code subcribers} is {@code null} or empty.
     *
     * @param event      the task that became due
     * @param subcribers the subscribers to notify; may be {@code null}
     */
    public void dispatchTask(Object event, CopyOnWriteArraySet<Subcriber> subcribers) {
        if (CollectionUtils.isEmpty(subcribers)) {
            return;
        }
        subcribers.forEach(target -> executorWay(event, target));
    }

    /**
     * Delivers {@code event} to one subscriber; the delivery strategy is up to the subclass.
     *
     * @param event     the task that became due
     * @param subcriber the subscriber receiving it
     */
    public abstract void executorWay(Object event, Subcriber subcriber);
}

3.4 GitHub地址

时间轮