文章目录
写在前面
目前市面上实现延时消息有比较多的成熟的中间件,RocketMQ、Kafka、RabbitMQ等等,这些中间件确实十分成熟,但是比较重。
本文手把手教你基于Java延时队列实现一个轻量级分布式延时消息,如有瑕疵还请大佬指正!
一、Enable模块
关于SpringBoot的Enable模块设计,请参考:
Spring注解驱动原理及源码,深入理解Spring注解驱动
1、定义注解入口
import org.springframework.context.annotation.Import;
import java.lang.annotation.*;
/**
* 开启倒计时功能
*/
@Documented
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Import({CountDownDefinitionRegistrar.class})
public @interface EnableCountDown {
String poolSizeString = "poolSize";
// 线程数
int poolSize() default 20;
}
在SpringBoot启动类上,加上@EnableCountDown就可以按需开启该功能,如不需要可以不开启,避免额外系统开销。
2、定义注册类
import org.jetbrains.annotations.NotNull;
import org.springframework.beans.factory.config.ConstructorArgumentValues;
import org.springframework.beans.factory.support.BeanDefinitionReaderUtils;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
import org.springframework.core.annotation.AnnotationAttributes;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
/**
* 延时任务注测
*/
public class CountDownDefinitionRegistrar implements ImportBeanDefinitionRegistrar {
@Override
public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, @NotNull BeanDefinitionRegistry registry) {
AnnotationAttributes annotationAttributes = AnnotationAttributes.fromMap(importingClassMetadata.getAnnotationAttributes(EnableCountDown.class.getName()));
if (annotationAttributes == null) {
return;
}
// 获取注解属性值
Number number = annotationAttributes.getNumber(EnableCountDown.poolSizeString);
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
if(number.intValue() <= 0){
number = 20;
}
//核心线程池大小
executor.setCorePoolSize(number.intValue());
//最大线程数
executor.setMaxPoolSize(number.intValue());
//活跃时间
executor.setKeepAliveSeconds(60);
//线程名字前缀
executor.setThreadNamePrefix("countDown-");
executor.initialize(); // 显式初始化
GenericBeanDefinition beanDefinition = new GenericBeanDefinition();
beanDefinition.setBeanClass(CountDownBasic.class);
ConstructorArgumentValues constructorArgumentValues = new ConstructorArgumentValues();
constructorArgumentValues.addIndexedArgumentValue(0, executor);
beanDefinition.setConstructorArgumentValues(constructorArgumentValues);
registry.registerBeanDefinition(CountDownBasic.BEAN_NAME, beanDefinition);
}
}
关于BeanDefinition的注册,请参考:深入浅出弄明白Spring的BeanDefinition,Spring-Beans部分源码分析
此处踩坑:
(1)ThreadPoolTaskExecutor必须显式调用initialize才能使用。
(2)通过编程方式注册的Bean,Idea是不认的,有强警告提示找不到该Bean,使用applicationContext.getBean(BEAN_NAME, CountDownBasic.class);静态方法可以完美解决。
二、延时任务模块
1、延时任务基础类
import com.task.context.task.delay.AbstractDelayTask;
import com.task.context.task.delay.DelayTaskManager;
import com.task.context.task.delay.DelayedTask;
import com.task.context.task.delay.TaskManager;
import com.task.context.task.model.TaskModel;
import oberon.commons.aide.JsonAide;
import org.jetbrains.annotations.NotNull;
import org.springframework.beans.BeansException;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.task.TaskExecutor;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
/**
* 延时任务基础类
*/
public class CountDownBasic implements ApplicationRunner, ApplicationContextAware {
static final String BEAN_NAME = "countDownBasic";
private static ApplicationContext applicationContext;
public static CountDownBasic getInstance() {
// 去警告
return applicationContext.getBean(BEAN_NAME, CountDownBasic.class);
}
private final TaskExecutor taskExecutor;
private final ScheduledExecutorService scheduledExecutorService = new ScheduledThreadPoolExecutor(1);
private TaskManager taskManager = new DelayTaskManager();
public CountDownBasic(TaskExecutor taskExecutor) {
this.taskExecutor = taskExecutor;
}
public TaskExecutor getTaskExecutor() {
return this.taskExecutor;
}
public ScheduledExecutorService getScheduledExecutorService() {
return this.scheduledExecutorService;
}
public TaskManager getDelayTaskManager() {
return this.taskManager;
}
public void executeTask(AbstractDelayTask<?> task){
// 持久化
saveData(task);
getDelayTaskManager().setTask(task);
}
/**
* 数据持久化
*/
private void saveData(DelayedTask<?> task) {
// TODO - 持久化到数据库
TaskModel taskModel = new TaskModel();
taskModel.setId(task.getTaskId());
taskModel.setTaskName(task.getTaskName());
taskModel.setEndTime(task.getEndTime());
taskModel.setClazzName(task.getClass().getName());
taskModel.setMethodName("execute");
taskModel.setData(JsonAide.toJson(task.getData()));
taskModel.setExecEnd(0);
taskModel.setDataClazzName(task.getData().getClass().getName());
System.out.println("持久化:" + JsonAide.toJson(taskModel));
}
/**
* 开始执行任务
*/
private void execTasks() {
System.out.println("开始执行任务");
// 执行任务
getScheduledExecutorService().scheduleWithFixedDelay(() -> {
AbstractDelayTask<?> task = getDelayTaskManager().getTask();
System.out.println("取出任务" + task.getTaskId());
this.taskExecutor.execute(task);
}, 1, 1, TimeUnit.MILLISECONDS);//(频率可以适当调节)
}
/**
* 初始化所有任务
*/
@SuppressWarnings("unchecked")
private void initTasks() throws Exception {
// TODO 从数据库获取任务列表
List<TaskModel> taskModels = new ArrayList<>();
TaskModel taskModel = JsonAide.fromJson("{\"id\":\"9c8524c16f654902944c3c43a372b7dc\",\"taskName\":\"测试任务\",\"endTime\":1665280410000,\"data\":\"{\\\"id\\\":\\\"ididid\\\",\\\"name\\\":\\\"zhangsan\\\"}\",\"clazzName\":\"com.task.context.task.model.TestTask\",\"methodName\":\"execute\",\"execEnd\":\"0\",\"dataClazzName\":\"com.task.context.task.model.TestData\"}", TaskModel.class);
taskModels.add(taskModel);
for (TaskModel model : taskModels) {
// 数据
Object o = JsonAide.fromJson(model.getData(), Class.forName(model.getDataClazzName()));
// 任务
AbstractDelayTask task = (AbstractDelayTask)Class.forName(model.getClazzName()).getDeclaredConstructor().newInstance();
task.taskId(model.getId()).taskName(model.getTaskName()).endTime(model.getEndTime()).data(o);
getDelayTaskManager().setTask(task);
}
}
/**
* 项目初始化完成后,初始化所有任务,开始获取任务
*/
@Override
public void run(ApplicationArguments args) throws Exception {
// 初始化所有任务
initTasks();
// 开始执行任务
execTasks();
}
@Override
public void setApplicationContext(@NotNull ApplicationContext applicationContext) throws BeansException {
CountDownBasic.applicationContext = applicationContext;
}
}
这里数据持久化到数据库,还未完成,涉及数据库,其实也很简单,自行实现即可。
2、延时任务管理器
public interface TaskManager {
//设置任务
void setTask(AbstractDelayTask<?> t);
// 获取任务
AbstractDelayTask<?> getTask();
}
import java.util.concurrent.DelayQueue;
public class DelayTaskManager implements TaskManager {
//唯一延时队列
private static final DelayQueue<AbstractDelayTask<?>> queue = new DelayQueue<>();
//设置任务
public void setTask(AbstractDelayTask<?> t){
queue.add(t);
}
// 获取任务
public AbstractDelayTask<?> getTask() {
try {
return queue.take();
} catch (InterruptedException e) {
// TODO exception
e.printStackTrace();
Thread.currentThread().interrupt();
}
return null;
}
}
3、延时任务类
import oberon.commons.aide.UuidAide;
import oberon.equipment.spring.clock.ClockHelper;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
public class DelayedTask<T> implements Delayed{
//唯一标识
private String taskId;
//任务名称
private String taskName;
//延时到多久
private Long endTime;
// 数据
private T data;
//定义时间工具类
public static final TimeUnit timeUnit = TimeUnit.SECONDS;
/**
* 默认生成id、name、60秒倒计时
*/
public DelayedTask() {
this.taskId = UuidAide.withoutSeparator();
this.taskName = Thread.currentThread().getName();
this.endTime = ClockHelper.timeMillis() + 60 * 1000L;
}
public DelayedTask<T> taskId(String taskId){
this.taskId = taskId;
return this;
}
public DelayedTask<T> taskName(String taskName){
this.taskName = taskName;
return this;
}
// 指定延时到多久
public DelayedTask<T> endTime(Long endTime){
this.endTime = endTime;
return this;
}
public DelayedTask<T> data(T data){
this.data = data;
return this;
}
// 指定延时时间
public DelayedTask<T> delayTime(Integer delayTime){
this.endTime = ClockHelper.timeMillis() + delayTime * 1000;
return this;
}
public String getTaskId() {
return taskId;
}
public String getTaskName() {
return taskName;
}
public Long getEndTime() {
return endTime;
}
public T getData() {
return data;
}
/**
* 用来判断是否到了截止时间
*/
@Override
public long getDelay(TimeUnit unit) {
return unit.convert(endTime, TimeUnit.MILLISECONDS) - unit.convert(System.currentTimeMillis(), TimeUnit.MILLISECONDS);
//return endTime - System.currentTimeMillis();
}
/**
* 相互批较排序用
*/
@Override
public int compareTo(Delayed o) {
// BaseTask o1 = (BaseTask)o;
return (int) (this.getDelay(TimeUnit.MILLISECONDS) - o.getDelay(TimeUnit.MILLISECONDS));
//return this.getDelay(this.timeUnit) - t.getDelay(this.timeUnit) > 0 ? 1:0;
}
}
import com.task.context.task.config.SpringUtil;
import oberon.equipment.spring.clock.ClockHelper;
import org.springframework.data.redis.core.StringRedisTemplate;
import java.util.concurrent.TimeUnit;
/**
* 延时任务类
*/
public abstract class AbstractDelayTask<T> extends DelayedTask<T> implements Runnable {
// 异步执行任务
@Override
public void run() {
System.out.println("准备执行延时任务:" + getTaskId() + "名称为:" + getTaskName());
// 支持分布式+集群
// TODO 数据库 + redis
StringRedisTemplate redisTemplate = SpringUtil.getBean("mainRedis", StringRedisTemplate.class);
if(Boolean.FALSE.equals(redisTemplate.opsForValue().setIfAbsent(getTaskId(), String.valueOf(ClockHelper.timeMillis()), 1, TimeUnit.MINUTES))) {
return;
}
System.out.println("开始执行延时任务:" + getTaskId() + "名称为:" + getTaskName());
try {
// TODO 判断数据库是否已执行过
if(execute(getData())){
// TODO 数据库中修改为已完成
}
System.out.println("延时任务执行结束:" + getTaskId() + "名称为:" + getTaskName());
} catch (Exception e) {
e.printStackTrace();
} finally {
redisTemplate.delete(getTaskId());
}
}
/**
* 具体执行的任务,需要实现该方法
* ack:返回true为执行成功,返回false为执行失败
*/
public abstract Boolean execute(T data);
}
延时任务主类,使用模板模式,这里是基于redis支持多集群互斥尚未完成,自行实现即可。
三、工具类
1、Spring上下文工具类
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;
/**
* 提供手动获取被spring管理的bean对象
*/
@Component
public class SpringUtil implements ApplicationContextAware {
private static ApplicationContext applicationContext;
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
if (SpringUtil.applicationContext == null) {
SpringUtil.applicationContext = applicationContext;
}
}
// 获取applicationContext
public static ApplicationContext getApplicationContext() {
return applicationContext;
}
// 通过name获取 Bean.
public static Object getBean(String name) {
return getApplicationContext().getBean(name);
}
// 通过class获取Bean.
public static <T> T getBean(Class<T> clazz) {
return getApplicationContext().getBean(clazz);
}
// 通过name,以及Clazz返回指定的Bean
public static <T> T getBean(String name, Class<T> clazz) {
return getApplicationContext().getBean(name, clazz);
}
}
四、任务模块
1、延时任务模型
/**
* 任务类
*/
public class TaskModel {
//唯一标识
private String id;
//任务名称
private String taskName;
//延时到多久
private Long endTime;
// 数据JSON
private String data;
// 类全限定名
private String clazzName;
// 方法名
private String methodName;
// 是否执行完毕 1完毕,0未执行
private Integer execEnd;
// data数据类型
private String dataClazzName;
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getTaskName() {
return taskName;
}
public void setTaskName(String taskName) {
this.taskName = taskName;
}
public Long getEndTime() {
return endTime;
}
public void setEndTime(Long endTime) {
this.endTime = endTime;
}
public String getData() {
return data;
}
public void setData(String data) {
this.data = data;
}
public String getClazzName() {
return clazzName;
}
public void setClazzName(String clazzName) {
this.clazzName = clazzName;
}
public String getMethodName() {
return methodName;
}
public void setMethodName(String methodName) {
this.methodName = methodName;
}
public Integer getExecEnd() {
return execEnd;
}
public void setExecEnd(Integer execEnd) {
this.execEnd = execEnd;
}
public String getDataClazzName() {
return dataClazzName;
}
public void setDataClazzName(String dataClazzName) {
this.dataClazzName = dataClazzName;
}
}
该任务类对应数据库model,持久化数据库的字段可以参考该类。
五、使用
1、测试任务类
import com.task.context.task.delay.AbstractDelayTask;
import java.util.Date;
/**
* 测试
*/
public class TestTask extends AbstractDelayTask<TestData> {
@Override
public Boolean execute(TestData data) {
// 自行实现编程式事务
System.out.println(new Date() + data.getId() + data.getName());
return true;
}
}
2、测试数据类
/**
* 测试
*/
public class TestData {
private String id;
private String name;
public TestData(String id, String name) {
this.id = id;
this.name = name;
}
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
}
3、编程实现看看效果
@RestController
@RequestMapping("task")
public class TaskController {
@RequestMapping("/test")
public String test(){
System.out.println(new Date() + "创建延时任务,延时30秒");
TestData testData = new TestData("ididid", "zhangsan");
AbstractDelayTask<TestData> testTask = new TestTask();
testTask.taskName("测试任务")
.delayTime(10) // 延时30秒执行
.data(testData);
CountDownBasic.getInstance().executeTask(testTask);
return "success";
}
}
持久化:{"id":"5b22451d15c1474485fd7e2d4bde8f05","taskName":"测试任务","endTime":1675755575590,"data":"{\"id\":\"ididid\",\"name\":\"zhangsan\"}","clazzName":"com.task.context.task.model.TestTask","methodName":"execute","execEnd":0,"dataClazzName":"com.task.context.task.model.TestData"}
取出任务5b22451d15c1474485fd7e2d4bde8f05
准备执行延时任务:5b22451d15c1474485fd7e2d4bde8f05名称为:测试任务
开始执行延时任务:5b22451d15c1474485fd7e2d4bde8f05名称为:测试任务
Tue Feb 07 15:40:20 CST 2023idididzhangsan
延时任务执行结束:5b22451d15c1474485fd7e2d4bde8f05名称为:测试任务