Springboot扩展-自定义实现分布式延时任务模块,Enable扩展-CSDN博客

83 阅读5分钟

文章目录

写在前面

目前市面上实现延时消息有比较多的成熟的中间件,RocketMQ、Kafka、RabbitMQ等等,这些中间件确实十分成熟,但是比较重。

本文手把手教你基于Java延时队列实现一个轻量级分布式延时消息,如有瑕疵还请大佬指正!

一、Enable模块

关于SpringBoot的Enable模块设计,请参考:
Spring注解驱动原理及源码,深入理解Spring注解驱动

1、定义注解入口

import org.springframework.context.annotation.Import;
import java.lang.annotation.*;

/**
 * 开启倒计时功能
 */
@Documented
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Import({CountDownDefinitionRegistrar.class})
public @interface EnableCountDown {

    String poolSizeString = "poolSize";

    // 线程数
    int poolSize() default 20;
}

在SpringBoot启动类上,加上@EnableCountDown就可以按需开启该功能,如不需要可以不开启,避免额外系统开销。

2、定义注册类


import org.jetbrains.annotations.NotNull;
import org.springframework.beans.factory.config.ConstructorArgumentValues;
import org.springframework.beans.factory.support.BeanDefinitionReaderUtils;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
import org.springframework.core.annotation.AnnotationAttributes;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

/**
 * 延时任务注测
 */
public class CountDownDefinitionRegistrar implements ImportBeanDefinitionRegistrar {

    @Override
    public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, @NotNull BeanDefinitionRegistry registry) {
        AnnotationAttributes annotationAttributes = AnnotationAttributes.fromMap(importingClassMetadata.getAnnotationAttributes(EnableCountDown.class.getName()));
        if (annotationAttributes == null) {
            return;
        }
		// 获取注解属性值
        Number number = annotationAttributes.getNumber(EnableCountDown.poolSizeString);


        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
        if(number.intValue() <= 0){
            number = 20;
        }
        //核心线程池大小
        executor.setCorePoolSize(number.intValue());
        //最大线程数
        executor.setMaxPoolSize(number.intValue());
        //活跃时间
        executor.setKeepAliveSeconds(60);
        //线程名字前缀
        executor.setThreadNamePrefix("countDown-");
        executor.initialize(); // 显式初始化

        GenericBeanDefinition beanDefinition = new GenericBeanDefinition();
        beanDefinition.setBeanClass(CountDownBasic.class);
        ConstructorArgumentValues constructorArgumentValues = new ConstructorArgumentValues();
        constructorArgumentValues.addIndexedArgumentValue(0, executor);
        beanDefinition.setConstructorArgumentValues(constructorArgumentValues);
        registry.registerBeanDefinition(CountDownBasic.BEAN_NAME, beanDefinition);
    }
}

关于BeanDefinition的注册,请参考:深入浅出弄明白Spring的BeanDefinition,Spring-Beans部分源码分析

此处踩坑:
(1)ThreadPoolTaskExecutor必须显式调用initialize才能使用。
(2)通过编程方式注册的Bean,Idea是不认的,有强警告提示找不到该Bean,使用applicationContext.getBean(BEAN_NAME, CountDownBasic.class);静态方法可以完美解决。

二、延时任务模块

1、延时任务基础类

import com.task.context.task.delay.AbstractDelayTask;
import com.task.context.task.delay.DelayTaskManager;
import com.task.context.task.delay.DelayedTask;
import com.task.context.task.delay.TaskManager;
import com.task.context.task.model.TaskModel;
import oberon.commons.aide.JsonAide;
import org.jetbrains.annotations.NotNull;
import org.springframework.beans.BeansException;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.task.TaskExecutor;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * 延时任务基础类
 */
public class CountDownBasic implements ApplicationRunner, ApplicationContextAware {

    static final String BEAN_NAME = "countDownBasic";

    private static ApplicationContext applicationContext;

    public static CountDownBasic getInstance() {
        // 去警告
        return applicationContext.getBean(BEAN_NAME, CountDownBasic.class);
    }

    private final TaskExecutor taskExecutor;

    private final ScheduledExecutorService scheduledExecutorService = new ScheduledThreadPoolExecutor(1);

    private TaskManager taskManager = new DelayTaskManager();

    public CountDownBasic(TaskExecutor taskExecutor) {
        this.taskExecutor = taskExecutor;
    }

    public TaskExecutor getTaskExecutor() {
        return this.taskExecutor;
    }

    public ScheduledExecutorService getScheduledExecutorService() {
        return this.scheduledExecutorService;
    }

    public TaskManager getDelayTaskManager() {
        return this.taskManager;
    }

    public void executeTask(AbstractDelayTask<?> task){
        // 持久化
        saveData(task);
        getDelayTaskManager().setTask(task);
    }

    /**
     * 数据持久化
     */
    private void saveData(DelayedTask<?> task) {
        // TODO - 持久化到数据库
        TaskModel taskModel = new TaskModel();
        taskModel.setId(task.getTaskId());
        taskModel.setTaskName(task.getTaskName());
        taskModel.setEndTime(task.getEndTime());
        taskModel.setClazzName(task.getClass().getName());
        taskModel.setMethodName("execute");
        taskModel.setData(JsonAide.toJson(task.getData()));
        taskModel.setExecEnd(0);
        taskModel.setDataClazzName(task.getData().getClass().getName());

        System.out.println("持久化:" + JsonAide.toJson(taskModel));
    }

    /**
     * 开始执行任务
     */
    private void execTasks() {
        System.out.println("开始执行任务");
        // 执行任务
        getScheduledExecutorService().scheduleWithFixedDelay(() -> {
            AbstractDelayTask<?> task = getDelayTaskManager().getTask();
            System.out.println("取出任务" + task.getTaskId());
            this.taskExecutor.execute(task);
        }, 1, 1,  TimeUnit.MILLISECONDS);//(频率可以适当调节)
    }

    /**
     * 初始化所有任务
     */
    @SuppressWarnings("unchecked")
    private void initTasks() throws Exception {
        // TODO 从数据库获取任务列表
        List<TaskModel> taskModels = new ArrayList<>();
        TaskModel taskModel = JsonAide.fromJson("{\"id\":\"9c8524c16f654902944c3c43a372b7dc\",\"taskName\":\"测试任务\",\"endTime\":1665280410000,\"data\":\"{\\\"id\\\":\\\"ididid\\\",\\\"name\\\":\\\"zhangsan\\\"}\",\"clazzName\":\"com.task.context.task.model.TestTask\",\"methodName\":\"execute\",\"execEnd\":\"0\",\"dataClazzName\":\"com.task.context.task.model.TestData\"}", TaskModel.class);
        taskModels.add(taskModel);

        for (TaskModel model : taskModels) {
            // 数据
            Object o = JsonAide.fromJson(model.getData(), Class.forName(model.getDataClazzName()));
            // 任务
            AbstractDelayTask task = (AbstractDelayTask)Class.forName(model.getClazzName()).getDeclaredConstructor().newInstance();
            task.taskId(model.getId()).taskName(model.getTaskName()).endTime(model.getEndTime()).data(o);

            getDelayTaskManager().setTask(task);
        }
    }

    /**
     * 项目初始化完成后,初始化所有任务,开始获取任务
     */
    @Override
    public void run(ApplicationArguments args) throws Exception {
        // 初始化所有任务
        initTasks();
        // 开始执行任务
        execTasks();
    }

    @Override
    public void setApplicationContext(@NotNull ApplicationContext applicationContext) throws BeansException {
        CountDownBasic.applicationContext = applicationContext;
    }
}

这里数据持久化到数据库,还未完成,涉及数据库,其实也很简单,自行实现即可。

2、延时任务管理器

public interface TaskManager {

    //设置任务
    void setTask(AbstractDelayTask<?> t);

    // 获取任务
    AbstractDelayTask<?> getTask();
}
import java.util.concurrent.DelayQueue;

public class DelayTaskManager implements TaskManager {

    //唯一延时队列
    private static final DelayQueue<AbstractDelayTask<?>> queue = new DelayQueue<>();
    //设置任务
    public void setTask(AbstractDelayTask<?> t){
        queue.add(t);
    }

    // 获取任务
    public AbstractDelayTask<?> getTask() {
        try {
            return queue.take();
        } catch (InterruptedException e) {
            // TODO exception
            e.printStackTrace();
            Thread.currentThread().interrupt();
        }
        return null;
    }
}

3、延时任务类


import oberon.commons.aide.UuidAide;
import oberon.equipment.spring.clock.ClockHelper;

import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;

public class DelayedTask<T> implements Delayed{
    //唯一标识
    private String taskId;
    //任务名称
    private String taskName;
    //延时到多久
    private Long endTime;
    // 数据
    private T data;
    //定义时间工具类
    public static final TimeUnit timeUnit = TimeUnit.SECONDS;

    /**
     * 默认生成id、name、60秒倒计时
     */
    public DelayedTask() {
        this.taskId = UuidAide.withoutSeparator();
        this.taskName = Thread.currentThread().getName();
        this.endTime = ClockHelper.timeMillis() + 60 * 1000L;
    }

    public DelayedTask<T> taskId(String taskId){
        this.taskId = taskId;
        return this;
    }
    public DelayedTask<T> taskName(String taskName){
        this.taskName = taskName;
        return this;
    }
    // 指定延时到多久
    public DelayedTask<T> endTime(Long endTime){
        this.endTime = endTime;
        return this;
    }
    public DelayedTask<T> data(T data){
        this.data = data;
        return this;
    }
    // 指定延时时间
    public DelayedTask<T> delayTime(Integer delayTime){
        this.endTime = ClockHelper.timeMillis() + delayTime * 1000;
        return this;
    }

    public String getTaskId() {
        return taskId;
    }

    public String getTaskName() {
        return taskName;
    }

    public Long getEndTime() {
        return endTime;
    }

    public T getData() {
        return data;
    }


    /**
     * 用来判断是否到了截止时间
     */
    @Override
    public long getDelay(TimeUnit unit) {
        return unit.convert(endTime, TimeUnit.MILLISECONDS) - unit.convert(System.currentTimeMillis(), TimeUnit.MILLISECONDS);
        //return endTime - System.currentTimeMillis();
    }

    /**
     * 相互批较排序用
     */
    @Override
    public int compareTo(Delayed o) {
        // BaseTask o1 = (BaseTask)o;
        return (int) (this.getDelay(TimeUnit.MILLISECONDS) - o.getDelay(TimeUnit.MILLISECONDS));
        //return this.getDelay(this.timeUnit) - t.getDelay(this.timeUnit) > 0 ? 1:0;
    }
}

import com.task.context.task.config.SpringUtil;
import oberon.equipment.spring.clock.ClockHelper;
import org.springframework.data.redis.core.StringRedisTemplate;

import java.util.concurrent.TimeUnit;

/**
 * 延时任务类
 */
public abstract class AbstractDelayTask<T> extends DelayedTask<T> implements Runnable {


    // 异步执行任务
    @Override
    public void run() {
        System.out.println("准备执行延时任务:" + getTaskId() + "名称为:" + getTaskName());
        // 支持分布式+集群
        // TODO 数据库 + redis
        StringRedisTemplate redisTemplate = SpringUtil.getBean("mainRedis", StringRedisTemplate.class);
        if(Boolean.FALSE.equals(redisTemplate.opsForValue().setIfAbsent(getTaskId(), String.valueOf(ClockHelper.timeMillis()), 1, TimeUnit.MINUTES))) {
            return;
        }
        System.out.println("开始执行延时任务:" + getTaskId() + "名称为:" + getTaskName());
        try {
            // TODO 判断数据库是否已执行过

            if(execute(getData())){
                // TODO 数据库中修改为已完成
            }
            System.out.println("延时任务执行结束:" + getTaskId() + "名称为:" + getTaskName());
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            redisTemplate.delete(getTaskId());
        }
    }
 
    /**
     * 具体执行的任务,需要实现该方法
     * ack:返回true为执行成功,返回false为执行失败
     */
    public abstract Boolean execute(T data);
 
}

延时任务主类,使用模板模式,这里是基于redis支持多集群互斥尚未完成,自行实现即可。

三、工具类

1、Spring上下文工具类

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;
 
/**
 * 提供手动获取被spring管理的bean对象
 */
@Component
public class SpringUtil implements ApplicationContextAware {
	
	private static ApplicationContext applicationContext;

	@Override
	public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
		if (SpringUtil.applicationContext == null) {
			SpringUtil.applicationContext = applicationContext;
		}
	}
 
	// 获取applicationContext
	public static ApplicationContext getApplicationContext() {
		return applicationContext;
	}
 
	// 通过name获取 Bean.
	public static Object getBean(String name) {
		return getApplicationContext().getBean(name);
	}
 
	// 通过class获取Bean.
	public static <T> T getBean(Class<T> clazz) {
		return getApplicationContext().getBean(clazz);
	}
 
	// 通过name,以及Clazz返回指定的Bean
	public static <T> T getBean(String name, Class<T> clazz) {
		return getApplicationContext().getBean(name, clazz);
	}
 
}

四、任务模块

1、延时任务模型


/**
 * 任务类
 */
public class TaskModel {

    //唯一标识
    private String id;
    //任务名称
    private String taskName;
    //延时到多久
    private Long endTime;
    // 数据JSON
    private String data;
    // 类全限定名
    private String clazzName;
    // 方法名
    private String methodName;
    // 是否执行完毕 1完毕,0未执行
    private Integer execEnd;
    // data数据类型
    private String dataClazzName;

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public String getTaskName() {
        return taskName;
    }

    public void setTaskName(String taskName) {
        this.taskName = taskName;
    }

    public Long getEndTime() {
        return endTime;
    }

    public void setEndTime(Long endTime) {
        this.endTime = endTime;
    }

    public String getData() {
        return data;
    }

    public void setData(String data) {
        this.data = data;
    }

    public String getClazzName() {
        return clazzName;
    }

    public void setClazzName(String clazzName) {
        this.clazzName = clazzName;
    }

    public String getMethodName() {
        return methodName;
    }

    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }

    public Integer getExecEnd() {
        return execEnd;
    }

    public void setExecEnd(Integer execEnd) {
        this.execEnd = execEnd;
    }

    public String getDataClazzName() {
        return dataClazzName;
    }

    public void setDataClazzName(String dataClazzName) {
        this.dataClazzName = dataClazzName;
    }
}

该任务类对应数据库model,持久化数据库的字段可以参考该类。

五、使用

1、测试任务类

import com.task.context.task.delay.AbstractDelayTask;

import java.util.Date;

/**
 * 测试
 */
public class TestTask extends AbstractDelayTask<TestData> {

    @Override
    public Boolean execute(TestData data) {
        // 自行实现编程式事务
        System.out.println(new Date() + data.getId() + data.getName());

        return true;
    }
}

2、测试数据类


/**
 * 测试
 */
public class TestData {

    private String id;

    private String name;

    public TestData(String id, String name) {
        this.id = id;
        this.name = name;
    }

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }
}

3、编程实现看看效果

@RestController
@RequestMapping("task")
public class TaskController {

    @RequestMapping("/test")
    public String test(){
        System.out.println(new Date() + "创建延时任务,延时30秒");
        TestData testData = new TestData("ididid", "zhangsan");
        AbstractDelayTask<TestData> testTask = new TestTask();
        testTask.taskName("测试任务")
                .delayTime(10) // 延时30秒执行
                .data(testData);
        CountDownBasic.getInstance().executeTask(testTask);
        return "success";
    }

}
持久化:{"id":"5b22451d15c1474485fd7e2d4bde8f05","taskName":"测试任务","endTime":1675755575590,"data":"{\"id\":\"ididid\",\"name\":\"zhangsan\"}","clazzName":"com.task.context.task.model.TestTask","methodName":"execute","execEnd":0,"dataClazzName":"com.task.context.task.model.TestData"}
取出任务5b22451d15c1474485fd7e2d4bde8f05
准备执行延时任务:5b22451d15c1474485fd7e2d4bde8f05名称为:测试任务
开始执行延时任务:5b22451d15c1474485fd7e2d4bde8f05名称为:测试任务
Tue Feb 07 15:40:20 CST 2023idididzhangsan
延时任务执行结束:5b22451d15c1474485fd7e2d4bde8f05名称为:测试任务