【学习笔记】使用切面实现数据权限

442 阅读 · 阅读约 2 分钟

1 前言

  业务系统中,经常需要使用数据权限来过滤用户能看到的数据。比如(本文的例子)用户只能看到自己创建的数据以及本部门的数据,不能看到其他部门的数据。实现该功能要求所有业务表都具备通用字段:创建人(create_user)、创建部门(create_dept)等。

  本文使用的持久层框架:MyBatis-Plus

2 代码实现

2.1 数据权限信息和缓存

/**
 * Data-permission context for the current call: which ids are visible and which
 * MyBatis mapped statements must be filtered. Built by DataPermissionAspect and
 * read by DefaultDataPermissionHandler.
 */
@Data
@Builder
public class DataPermissionInfo {

    /**
     * Ids of users whose records are visible (compared against the create_user column).
     */
    private List<Long> userIds;

    /**
     * Ids of departments whose records are visible (compared against the create_dept column).
     */
    private List<Long> deptIds;

    /**
     * Whether the user may see all data; when true the handler leaves the query unchanged.
     */
    private Boolean haveAll;

    /**
     * Mapped statement ids that must be filtered, each mapped to its table alias.
     * The structure is a prefix trie; here it can be thought of as a Map<String, String>.
     * Used when more than one statement is configured (count > 1).
     */
    private PrefixTrie<String> sql;

    /**
     * Whether every statement should be filtered, regardless of its id.
     */
    private Boolean allSql;

    /**
     * Table alias used when allSql is true.
     */
    private String defaultTableAlias;

    /**
     * Used when exactly one statement is configured:
     * [0] = mapped statement id, [1] = table alias.
     */
    private String[] onlySql;

    /**
     * Number of configured statements; selects between onlySql (== 1) and sql (> 1).
     */
    private Integer count;
}
/**
 * Per-thread holder for the {@link DataPermissionInfo} of the call currently being executed.
 * DataPermissionAspect writes it around annotated methods and DefaultDataPermissionHandler
 * reads it when MyBatis-Plus rewrites a statement.
 */
public final class DataPermissionThreadLocal {

    private static final ThreadLocal<DataPermissionInfo> DATA_PERMISSION_CACHE = new ThreadLocal<>();

    private DataPermissionThreadLocal() {
        // Static holder only; not meant to be instantiated.
    }

    /**
     * Removes the current thread's entry entirely, so that a pooled thread does not carry it
     * into the next task.
     */
    public static void clear() {
        DATA_PERMISSION_CACHE.remove();
    }

    /**
     * Stores the context for the current thread.
     *
     * @param info the context to store; {@code null} is treated as {@link #clear()} so that no
     *             empty entry is left behind
     */
    public static void set(DataPermissionInfo info) {
        if (info == null) {
            DATA_PERMISSION_CACHE.remove();
        } else {
            DATA_PERMISSION_CACHE.set(info);
        }
    }

    /**
     * Returns the context stored for the current thread.
     *
     * @return the stored context, or {@code null} when none is set
     */
    public static DataPermissionInfo get() {
        return DATA_PERMISSION_CACHE.get();
    }

}

2.2 自定义注解

/**
 * Restricts a single MyBatis mapped statement, executed inside the annotated method, to the
 * rows the current user may see.
 *
 * <p>A method often runs several queries, and only some of them need data-permission
 * filtering. This annotation names the one statement that does.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface DataPermission {

    /**
     * Full mapped statement id to filter, e.g. {@code com.example.mapper.FooMapper.page}.
     * Only a statement whose id equals this value is filtered. NOTE(review): the empty
     * default presumably never equals a real statement id, so the default filters nothing.
     */
    String sql() default "";

    /**
     * Alias that qualifies create_dept / create_user in the filtered statement. An empty
     * value means the columns are used unqualified.
     */
    String tableAlias() default "";
}
/**
 * Applies data permission to several statements of one method, or to every statement it runs.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface DataPermissionGroup {

    /**
     * Per-statement rules (statement id plus table alias). Only the listed statements are
     * filtered. This list is not consulted when {@link #allSql()} is true.
     */
    DataPermission[] permissionSql() default {};

    /**
     * When true, every statement executed in the method is filtered using
     * {@link #defaultTableAlias()}.
     */
    boolean allSql() default false;

    /**
     * Table alias used when {@link #allSql()} is true. An empty value means unqualified columns.
     */
    String defaultTableAlias() default "";
}

2.3 注解切面

package org.springblade.common.permission;

import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springblade.common.utils.PrefixTrie;
import org.springblade.core.secure.BladeUser;
import org.springblade.core.secure.utils.AuthUtil;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

@Aspect
@Component
public class DataPermissionAspect {

    private static final Logger LOGGER = LoggerFactory.getLogger(DataPermissionAspect.class);

    /**
     * Matches methods annotated with {@link DataPermission}.
     */
    @Pointcut("@annotation(org.springblade.common.permission.DataPermission)")
    public void pointcut() {
    }

    /**
     * Matches methods annotated with {@link DataPermissionGroup}.
     */
    @Pointcut("@annotation(org.springblade.common.permission.DataPermissionGroup)")
    public void pointcutForGroup() {
    }

    /**
     * Installs the context described by a single {@link DataPermission} while the
     * intercepted method runs.
     */
    @Around("pointcut()")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        DataPermission dataPermission = findAnnotation(joinPoint, DataPermission.class);
        if (Objects.isNull(dataPermission)) {
            return joinPoint.proceed();
        }
        DataPermissionInfo info = getDataPermissionInfo(new DataPermission[]{dataPermission}, false, "");
        return proceedWithInfo(joinPoint, info);
    }

    /**
     * Installs the context described by a {@link DataPermissionGroup} while the intercepted
     * method runs.
     */
    @Around("pointcutForGroup()")
    public Object aroundForGroup(ProceedingJoinPoint joinPoint) throws Throwable {
        DataPermissionGroup dataPermissionGroup = findAnnotation(joinPoint, DataPermissionGroup.class);
        if (Objects.isNull(dataPermissionGroup)) {
            return joinPoint.proceed();
        }
        DataPermissionInfo info = getDataPermissionInfo(dataPermissionGroup.permissionSql(),
                dataPermissionGroup.allSql(), dataPermissionGroup.defaultTableAlias());
        return proceedWithInfo(joinPoint, info);
    }

    /**
     * Runs the join point with {@code info} installed, then puts back whatever context was
     * active before the call.
     */
    private static Object proceedWithInfo(ProceedingJoinPoint joinPoint, DataPermissionInfo info) throws Throwable {
        // An enclosing annotated method may already have installed a context. Clearing
        // unconditionally would leave the caller's remaining queries unfiltered.
        DataPermissionInfo previous = DataPermissionThreadLocal.get();
        DataPermissionThreadLocal.set(info);
        try {
            return joinPoint.proceed();
        } finally {
            if (Objects.isNull(previous)) {
                // Outermost call: drop the entry so pooled threads do not keep it.
                DataPermissionThreadLocal.clear();
            } else {
                DataPermissionThreadLocal.set(previous);
            }
        }
    }

    /**
     * Looks up {@code annotationType} on the intercepted method.
     *
     * <p>It first checks the signature method. If the annotation is absent there (for example
     * when the signature exposes an interface method but the annotation sits on the
     * implementation), it falls back to the same public method on the target class.
     *
     * @return the annotation, or {@code null} when it cannot be found
     */
    private static <A extends java.lang.annotation.Annotation> A findAnnotation(ProceedingJoinPoint joinPoint,
                                                                                Class<A> annotationType) {
        Method method = ((MethodSignature) joinPoint.getSignature()).getMethod();
        A annotation = method.getAnnotation(annotationType);
        Object target = joinPoint.getTarget();
        if (Objects.nonNull(annotation) || Objects.isNull(target)) {
            return annotation;
        }
        try {
            annotation = target.getClass()
                    .getMethod(method.getName(), method.getParameterTypes())
                    .getAnnotation(annotationType);
        } catch (NoSuchMethodException e) {
            annotation = null;
        }
        if (Objects.isNull(annotation)) {
            LOGGER.warn("@{} not found on {}; proceeding without data permission",
                    annotationType.getSimpleName(), method);
        }
        return annotation;
    }

    /**
     * Builds the permission context for the current user.
     *
     * @param dataPermissionList statement/alias rules; ignored when {@code allSql} is true
     * @param allSql             whether every statement should be filtered
     * @param defaultTableAlias  alias used when {@code allSql} is true
     * @return the context, or {@code null} when there is nothing to filter
     * @throws IllegalStateException when no current user is available
     */
    private DataPermissionInfo getDataPermissionInfo(DataPermission[] dataPermissionList, boolean allSql, String defaultTableAlias) {
        BladeUser user = AuthUtil.getUser();
        if (Objects.isNull(user)) {
            // Fail closed: without a user no restriction can be built.
            throw new IllegalStateException("No current user available to resolve data permission");
        }
        // Optional short-circuit for administrators, e.g.:
        // if (user.getRoleName().contains("超级管理员")) {
        //     return DataPermissionInfo.builder().haveAll(true).build();
        // }
        // TODO query the database for the departments this user is authorized for
        List<Long> deptIds = new ArrayList<>();
        Long curUserId = user.getUserId();

        DataPermissionInfo.DataPermissionInfoBuilder builder = DataPermissionInfo.builder().haveAll(false)
                .deptIds(deptIds).userIds(Collections.singletonList(curUserId));
        // Every statement in the method is filtered.
        if (allSql) {
            return builder.allSql(true).defaultTableAlias(defaultTableAlias).build();
        }
        if (dataPermissionList.length == 0) {
            return null;
        }
        builder = builder.allSql(false).count(dataPermissionList.length);
        if (dataPermissionList.length == 1) {
            // Single rule: store it directly instead of building a trie.
            DataPermission dataPermission = dataPermissionList[0];
            return builder.onlySql(new String[]{dataPermission.sql(), dataPermission.tableAlias()}).build();
        }
        PrefixTrie<String> sql = new PrefixTrie<>();
        for (DataPermission dataPermission : dataPermissionList) {
            sql.addWord(dataPermission.sql(), dataPermission.tableAlias());
        }
        return builder.sql(sql).build();
    }
}

2.4 数据权限处理器

package org.springblade.common.permission;

import com.baomidou.mybatisplus.extension.plugins.handler.DataPermissionHandler;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.HexValue;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.ItemsList;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.lang3.BooleanUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

public class DefaultDataPermissionHandler implements DataPermissionHandler {

    private static final Logger LOGGER = LoggerFactory.getLogger(DefaultDataPermissionHandler.class);

    /**
     * Sentinel id meaning "no permission" in a dimension where the user has no ids.
     * Assumes no real record carries this id.
     */
    private static final long NO_PERMISSION_ID = -999999L;

    /**
     * Appends the data-permission condition to {@code where} when the current context asks
     * for {@code mappedStatementId} to be filtered.
     *
     * <p>Resulting shape: {@code <original where> AND (create_dept ... OR create_user ...)}.
     * When {@code where} is null, only the parenthesized permission condition is returned.
     *
     * @param where             the statement's existing WHERE expression; may be null
     * @param mappedStatementId MyBatis mapped statement id of the statement being executed
     * @return the original {@code where} when no filtering applies, otherwise the combined
     *         expression
     */
    @Override
    public Expression getSqlSegment(Expression where, String mappedStatementId) {
        // No annotation means no context was stored; full access needs no filtering either.
        DataPermissionInfo info = DataPermissionThreadLocal.get();
        if (Objects.isNull(info) || BooleanUtils.isTrue(info.getHaveAll())) {
            return where;
        }
        String tableAlias = getTableAlias(info, mappedStatementId);
        if (Objects.isNull(tableAlias)) {
            // This statement is not one of the configured ones: leave it untouched.
            return where;
        }
        String prefix = StringUtils.isEmpty(tableAlias) ? "" : tableAlias + ".";
        Expression dept = createExpression(prefix + "create_dept", info.getDeptIds());
        Expression user = createExpression(prefix + "create_user", info.getUserIds());
        // OrExpressionExtend renders its own parentheses so the OR binds inside the AND.
        Expression permission = new OrExpressionExtend(dept, user);
        LOGGER.debug("Applying data permission to {} (alias '{}')", mappedStatementId, tableAlias);
        return Objects.isNull(where) ? permission : new AndExpression(where, permission);
    }

    /**
     * Resolves the table alias for {@code checkSql}.
     *
     * @return the alias (possibly empty) when the statement must be filtered, otherwise null
     */
    private String getTableAlias(DataPermissionInfo info, String checkSql) {
        if (BooleanUtils.isTrue(info.getAllSql())) {
            return info.getDefaultTableAlias();
        }
        Integer count = info.getCount();
        if (Objects.isNull(count)) {
            return null;
        }
        if (count == 1) {
            String[] onlySql = info.getOnlySql();
            if (Objects.nonNull(onlySql) && Objects.equals(onlySql[0], checkSql)) {
                return onlySql[1];
            }
        } else if (count > 1 && Objects.nonNull(info.getSql())) {
            return info.getSql().get(checkSql);
        }
        return null;
    }

    /**
     * Builds {@code column = id} for one id, {@code column IN (...)} for several, and
     * {@code column = NO_PERMISSION_ID} when there are no usable (non-null) ids.
     */
    private Expression createExpression(String column, List<Long> ids) {
        List<Expression> values = Objects.isNull(ids) ? null : ids.stream()
                .filter(Objects::nonNull)
                .<Expression>map(LongValue::new)
                .collect(Collectors.toList());
        if (Objects.isNull(values) || values.isEmpty()) {
            // Never emit an empty "IN ()"; compare against the sentinel so no rows match.
            return equalsTo(column, new LongValue(NO_PERMISSION_ID));
        }
        if (values.size() == 1) {
            return equalsTo(column, values.get(0));
        }
        InExpression inExpression = new InExpression();
        inExpression.setLeftExpression(new Column(column));
        ItemsList valueList = new ExpressionList(values);
        inExpression.setRightItemsList(valueList);
        return inExpression;
    }

    /**
     * Builds {@code column = value}.
     */
    private static Expression equalsTo(String column, Expression value) {
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(new Column(column));
        equalsTo.setRightExpression(value);
        return equalsTo;
    }
}

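JSqlParser 原生的 OrExpression 在输出 SQL 时不会加括号。它与原始条件用 AND 拼接后,由于 AND 的优先级高于 OR,最终会变成 `原始条件 AND create_dept ... OR create_user ...`,导致过滤条件失效。所以需要自定义扩展,在外层加上括号。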

/**
 * OR expression over any number of operands that renders itself wrapped in parentheses,
 * e.g. {@code (a OR b)}, so that it keeps its grouping when combined with AND.
 *
 * <p>Only {@link #toString()} is overridden. The inherited left/right operands are not
 * populated, so code that inspects this node through the OrExpression getters or visitors
 * will see null children.
 */
public class OrExpressionExtend extends OrExpression {

    private final List<Expression> expressionList;

    /**
     * @param expressions the operands; at least one, none null
     * @throws IllegalArgumentException if no operand is given or any operand is null
     */
    public OrExpressionExtend(Expression... expressions) {
        if (expressions == null || expressions.length == 0) {
            // Zero operands would render as "()", which is invalid SQL.
            throw new IllegalArgumentException("OrExpressionExtend requires at least one expression");
        }
        expressionList = new ArrayList<>(Arrays.asList(expressions));
        if (expressionList.contains(null)) {
            // Fail here rather than later with an NPE while rendering the SQL.
            throw new IllegalArgumentException("OrExpressionExtend does not accept null expressions");
        }
    }

    @Override
    public String toString() {
        return "(" + expressionList.stream().map(Object::toString).collect(Collectors.joining(" OR ")) + ")";
    }
}

2.5 配置拦截器

/**
 * Registers the MyBatis-Plus interceptor chain used by the application.
 */
@Configuration
public class MybatisPlusConfiguration {

    /**
     * Builds the interceptor chain: data permission first, then MySQL pagination.
     */
    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        // Data-permission plugin, backed by the custom handler.
        DataPermissionInterceptor permissionInterceptor = new DataPermissionInterceptor();
        permissionInterceptor.setDataPermissionHandler(new DefaultDataPermissionHandler());

        MybatisPlusInterceptor chain = new MybatisPlusInterceptor();
        chain.addInnerInterceptor(permissionInterceptor);
        // Pagination must be registered after the data-permission plugin.
        chain.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));
        return chain;
    }
}

3 使用效果

  在需要过滤数据权限的查询方法上加上注解即可:

@Service
public class DemoServiceImpl extends BaseServiceImpl<DemolMapper, DemoEntity> implements IDemoService {
    
    /**
     * Returns one page of records. Only the mapped statement named in {@code sql} is
     * filtered, using alias {@code t1}.
     * NOTE(review): the statement id points at SealMapper while baseMapper is DemolMapper;
     * confirm that the id matches the statement actually executed, otherwise nothing is
     * filtered.
     */
    @Override
    @DataPermission(sql = "org.springblade.modules.seal.mapper.SealMapper.page", tableAlias = "t1")
    public IPage<SealVO> page(IPage<SealVO> page, SealDTO dto) {
        List<SealVO> records = baseMapper.page(page, dto);
        // Follow-up queries here (e.g. loading foreign-key details) are not affected by data permission.
        return page.setRecords(records);
    }
}

最终的查询日志如下。可以看到,数据权限条件同时作用于分页查询的 count 语句和具体的查询语句:

select count(*) ...... where t1.is_deleted = 0 and t2.type_code = 1 and (t1.create_dept = -999999 or t1.create_user = 1123598821738675201)

select ...... where t1.is_deleted = 0 and t2.type_code = 1 and (t1.create_dept = -999999 or t1.create_user = 1123598821738675201) order by t1.id desc limit 10