接口防刷-Java基于注解实现接口防刷

70 阅读3分钟

-原文链接:blog.csdn.net/qq_47614329…

流程原理: 定义一切面,通过@Prevent注解作为切入点、在该切面的前置通知获取该方法的所有入参并将其Base64编码,将入参Base64编码+完整方法名+(ip或者用户等)作为redis的key,入参作为reids的value,@Prevent的value作为redis的expire,存入redis; 每次进来这个切面根据入参Base64编码+完整方法名判断redis值是否存在,存在则拦截防刷,不存在则允许调用;

一、定义注解

package com.*.web.annotation;

import java.lang.annotation.*;
 
/**
 * 接口防刷注解
 */
@Documented
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface Prevent {
 
    /**
     * 限制的时间值(秒)默认60s
     */
    long value() default 60;

    /**
     * 限制规定时间内访问次数,默认只能访问一次
     */
    long times() default 1;
 
    /**
     * 提示
     */
    String message() default "";
 
    /**
     * 策略
     */
    PreventStrategy strategy() default PreventStrategy.DEFAULT;
}

二、防刷策略

public enum PreventStrategy {

    DEFAULT

}

三、定义切面

package com.zzyt.web.aop;

import com.*.web.annotation.Prevent;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import com.*.web.constants.PreventStrategy;

/**
 * 防刷切面实现类
 */
@Aspect
@Component
public class PreventAop {

    @Resource
    private RedisTemplate<String, Long> redisTemplate;

    final String redisKey = "PREVENT_METHOD_NAME:";


    /**
     * 切入点
     */
    @Pointcut("@annotation(com.zzyt.web.annotation.Prevent)")
    public void pointcut() {}


    /**
     * 处理前
     */
    @Before("pointcut()")
    public void joinPoint(JoinPoint joinPoint) throws Exception {
        // 获取调用者ip
        RequestAttributes requestAttributes = RequestContextHolder.currentRequestAttributes();
        HttpServletRequest httpServletRequest = ((ServletRequestAttributes) requestAttributes).getRequest();
//        String userIP = IpUtils.getUserIP(httpServletRequest);
        String userIP = httpServletRequest.getRemoteHost();
        // 获取调用接口方法名
        MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
        Method method = joinPoint.getTarget().getClass().getMethod(
                methodSignature.getName(),
                methodSignature.getParameterTypes()); // 获取该接口方法
        String methodFullName = method.getDeclaringClass().getName() + method.getName(); // 获取到方法名
        Prevent preventAnnotation = method.getAnnotation(Prevent.class); // 获取该接口上的prevent注解(为了使用该注解内的参数)
        // 执行对应策略
        entrance(preventAnnotation, userIP, methodFullName);
    }


    /**
     * 通过prevent注册判断执行策略
     * @param prevent 该接口的prevent注解对象
     * @param userIP 访问该接口的用户ip
     * @param methodFullName 该接口方法名
     */
    private void entrance(Prevent prevent, String userIP, String methodFullName) throws Exception {
        PreventStrategy strategy = prevent.strategy(); // 获取校验策略
        if (Objects.requireNonNull(strategy) == PreventStrategy.DEFAULT) { // 默认就是default策略,执行default策略方法
            defaultHandle(userIP, prevent, methodFullName);
        } else {
            throw new WrongDataException("无效的策略");
        }
    }


    /**
     * Default测试执行方法
     * @param userIP 访问该接口的用户ip
     * @param prevent 该接口的prevent注解对象
     * @param methodFullName 该接口方法名
     */
    private void defaultHandle(String userIP, Prevent prevent, String methodFullName) throws Exception {
        String base64StrIP = toBase64String(userIP); // 加密用户ip(避免ip存在一些特殊字符作为redis的key不合法)
        long expire = prevent.value(); // 获取访问限制时间
        long times = prevent.times(); // 获取访问限制次数
        // 限制特定时间内访问特定次数
        long count = redisTemplate.opsForValue().increment(
                redisKey + base64StrIP + ":" + methodFullName, 1); // 访问次数+1
        if (count == 1) { // 如果访问次数为1,则重置访问限制时间(即redis超时时间)
            redisTemplate.expire(
                    redisKey + base64StrIP + ":" + methodFullName,
                    expire,
                    TimeUnit.SECONDS);
        }
        if (count > times) { // 如果访问次数超出访问限制次数,则禁止访问
            // 如果有限制信息则使用限制信息,没有则使用默认限制信息
            String errorMessage =
                    !StringUtils.isEmpty(prevent.message()) ? prevent.message() : expire + "秒内不允许重复请求";
            throw new WrongDataException( errorMessage);

        }
    }


    /**
     * 对象转换为base64字符串
     * @param obj 对象值
     * @return base64字符串
     */
    private String toBase64String(String obj) throws Exception {
        if (StringUtils.isEmpty(obj)) {
            return null;
        }
        Base64.Encoder encoder = Base64.getEncoder();
        byte[] bytes = obj.getBytes(StandardCharsets.UTF_8);
        return encoder.encodeToString(bytes);
    }
}

四、测试

package com.*.web.controller;
@RestController
@RequestMapping("/warningNotification")
public class TestController {
    @Prevent(value = 60,times = 5)
    @PostMapping("/test")
    public ResponseData test(HttpServletRequest request) {
        return ResponseData.success();
    }
}