动态管理数据范围 | 青训营笔记

126 阅读4分钟

这是我参与「第五届青训营 」伴学笔记创作活动的第 13 天

我们经常会遇到这样的场景:一个公司有多个业务部门,每个部门只能查看和操作本部门的数据,而这些部门的上级部门却可以看到所有下属部门的数据。如何动态地管理每个用户可见的数据范围(本人、本部门、本部门及子部门、全部)?经过学习,我学到了如下这种方式。

总的来说,思路就是在需要限制数据范围的查询方法上加一个注解,拦截该方法将要执行的sql语句,并插入用于划分数据范围的sql条件,而插入的这段sql会根据配置和发起请求的用户动态变化。

话不多说,直接放代码。

数据表的设计

-- Per-role data-scope configuration.
-- Each row carries a role id, a data-scope type, and three separate scope
-- settings: one for viewing, one for updating and one for deleting.
-- NOTE(review): presumably this is the table behind DataScopeRoleEntity /
-- DataScopeRoleDao used by DataScopeViewService - confirm against the mapper.
-- NOTE(review): there is no unique key on (role_id, data_scope_type); verify
-- whether duplicate rows for the same pair are meant to be possible.
create table t_role_data_scope
(
    id              int auto_increment
        primary key,
    data_scope_type int      default 0                 not null comment '数据范围id',
    view_type       int      default 0                 not null comment '查找范围',
    -- update_type and delete_type are nullable, unlike view_type.
    update_type     int      default 0                 null comment '更新范围',
    delete_type     int      default 0                 null comment '删除范围',
    role_id         int      default -1                not null comment '角色id',
    update_time     datetime default CURRENT_TIMESTAMP not null on update CURRENT_TIMESTAMP comment '更新时间',
    create_time     datetime default CURRENT_TIMESTAMP not null comment '创建时间'
);

注解定义

/**
 * Method-level marker describing the data-scope condition that should be
 * added to the SQL executed by the annotated method.
 *
 * <p>Retained at runtime so it can be discovered reflectively;
 * {@code DataScopeSqlConfigService} scans for methods carrying it.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface DataScope {

    /**
     * Module scope; keep it consistent with the module the annotated method belongs to.
     * @return the data-scope type, {@code DataScopeTypeEnum.DEFAULT} unless overridden
     */
    DataScopeTypeEnum dataScopeType() default DataScopeTypeEnum.DEFAULT;

    /**
     * Kind of condition to apply: by user id, by department id, or a custom rule.
     * Unless custom, USER corresponds to the '#userIds' placeholder and
     * DEPARTMENT corresponds to the '#departmentIds' placeholder.
     */
    DataScopeWhereInTypeEnum whereInType() default DataScopeWhereInTypeEnum.USER;

    /**
     * The operation this scope applies to.
     * Choose it according to the SQL that is actually executed.
     * The web UI lets view / update / delete scopes be configured separately
     * per role; if this value does not match the real operation they will conflict.
     */
    DataScopeSqlTypeEnum dataScopeSqlType() default DataScopeSqlTypeEnum.VIEW;

    /**
     * Only usable when the where-in type is DataScopeWhereInTypeEnum.CUSTOM_STRATEGY.
     * @return the strategy implementation class
     */
    Class<? extends DataScopePowerStrategy> joinSqlImplClazz()  default DataScopePowerStrategy.class;

    /**
     *
     * Which "where" occurrence in the statement to attach to, counted from 0.
     * @return zero-based index of the target where clause
     */
    int whereIndex() default 0;

    /**
     * Ignored when the where-in type is CUSTOM_STRATEGY.
     * Extra SQL to add as one of the conditions of the view / update / delete
     * statement; do not include the "where" keyword.
     * Note that only the '#userIds' and '#departmentIds' placeholders are
     * replaced dynamically, based on the current user's configuration.
     */
    String joinSql() default "";
}

注解逻辑service

/**
 * Caches the {@link DataScope} configuration of every annotated method found
 * under the configured scan package and turns such a configuration into the
 * SQL fragment that should be spliced into the intercepted statement for the
 * current request user.
 */
@Slf4j
@Service
public class DataScopeSqlConfigService {

    /**
     * Placeholder in {@code joinSql} that is replaced with a comma-separated list of user ids.
     */
    private static final String USER_PARAM = "#userIds";

    /**
     * Placeholder in {@code joinSql} that is replaced with a comma-separated list of department ids.
     */
    private static final String DEPARTMENT_PARAM = "#departmentIds";

    /**
     * Cache keyed by {@code SimpleClassName.methodName}.
     */
    private final ConcurrentHashMap<String, DataScopeSqlConfigDTO> dataScopeMethodMap = new ConcurrentHashMap<>();

    @Autowired
    private DataScopeViewService dataScopeViewService;

    @Value("${swagger.packAge}")
    private String scanPackage;

    /**
     * Populates the method cache once the bean has been constructed.
     */
    @PostConstruct
    private void initDataScopeMethodMap() {
        this.refreshDataScopeMethodMap();
        log.info("缓存数据范围注解方法success");
    }

    /**
     * Scans {@code scanPackage} for methods annotated with {@link DataScope}
     * and (re)registers their configuration in the cache.
     *
     * @return the cache map, keyed by {@code SimpleClassName.methodName}
     */
    private Map<String, DataScopeSqlConfigDTO> refreshDataScopeMethodMap() {
        Reflections reflections = new Reflections(new ConfigurationBuilder()
                .setUrls(ClasspathHelper.forPackage(scanPackage))
                .setScanners(new MethodAnnotationsScanner()));
        Set<Method> methods = reflections.getMethodsAnnotatedWith(DataScope.class);
        for (Method method : methods) {
            DataScope dataScopeAnnotation = method.getAnnotation(DataScope.class);
            if (dataScopeAnnotation == null) {
                continue;
            }
            DataScopeSqlConfigDTO configDTO = new DataScopeSqlConfigDTO();
            configDTO.setDataScopeType(dataScopeAnnotation.dataScopeType());
            configDTO.setJoinSql(dataScopeAnnotation.joinSql());
            configDTO.setWhereIndex(dataScopeAnnotation.whereIndex());
            configDTO.setDataScopeWhereInType(dataScopeAnnotation.whereInType());
            configDTO.setDataScopeSqlType(dataScopeAnnotation.dataScopeSqlType());
            // Without this the CUSTOM_STRATEGY branch of getJoinSql always sees a
            // null strategy class and silently returns an empty condition.
            configDTO.setJoinSqlImplClazz(dataScopeAnnotation.joinSqlImplClazz());
            // Key must stay in sync with the "Class.method" key the SQL interceptor
            // derives from the mapped statement id.
            String key = method.getDeclaringClass().getSimpleName() + "." + method.getName();
            dataScopeMethodMap.put(key, configDTO);
        }
        return dataScopeMethodMap;
    }

    /**
     * Looks up the cached configuration of a method.
     *
     * @param method key in the form {@code SimpleClassName.methodName}
     * @return the configuration, or {@code null} when the method is not annotated
     */
    public DataScopeSqlConfigDTO getSqlConfig(String method) {
        return this.dataScopeMethodMap.get(method);
    }

    /**
     * Builds the SQL condition to splice into the statement for the current
     * request user.
     *
     * <p>An empty string means "append nothing": it is returned when the custom
     * strategy cannot be resolved, when the resolved id list is empty, or when
     * the where-in type is none of CUSTOM_STRATEGY / USER / DEPARTMENT.
     *
     * @param sqlConfigDTO configuration of the intercepted method
     * @return the condition SQL without a leading "where", or an empty string
     */
    public String getJoinSql(DataScopeSqlConfigDTO sqlConfigDTO) {
        DataScopeTypeEnum dataScopeTypeEnum = sqlConfigDTO.getDataScopeType();
        DataScopeSqlTypeEnum dataScopeSqlType = sqlConfigDTO.getDataScopeSqlType();
        DataScopeWhereInTypeEnum whereInType = sqlConfigDTO.getDataScopeWhereInType();
        String joinSql = sqlConfigDTO.getJoinSql();
        Long userId = SmartRequestTokenUtil.getRequestUserId();
        Integer roleId = SmartRequestTokenUtil.getThreadLocalUser().getRoleId();

        if (DataScopeWhereInTypeEnum.CUSTOM_STRATEGY == whereInType) {
            Class<?> strategyClass = sqlConfigDTO.getJoinSqlImplClazz();
            if (strategyClass == null) {
                log.warn("data scope custom strategy class is null");
                return "";
            }
            DataScopePowerStrategy powerStrategy =
                    (DataScopePowerStrategy) SmartApplicationContext.getBean(sqlConfigDTO.getJoinSqlImplClazz());
            if (powerStrategy == null) {
                log.warn("data scope custom strategy class:{} ,bean is null", strategyClass);
                return "";
            }
            DataScopeViewTypeEnum viewTypeEnum =
                    dataScopeViewService.getEmployeeDataScopeViewType(dataScopeTypeEnum, dataScopeSqlType, roleId);
            return powerStrategy.getCondition(viewTypeEnum, sqlConfigDTO);
        }
        if (DataScopeWhereInTypeEnum.USER == whereInType) {
            List<Long> canViewEmployeeIds =
                    dataScopeViewService.getCanViewEmployeeId(dataScopeTypeEnum, dataScopeSqlType, roleId, userId);
            return this.fillPlaceholder(joinSql, USER_PARAM, canViewEmployeeIds);
        }
        if (DataScopeWhereInTypeEnum.DEPARTMENT == whereInType) {
            List<Long> canViewDepartmentIds =
                    dataScopeViewService.getCanViewDepartmentId(dataScopeTypeEnum, dataScopeSqlType, roleId, userId);
            return this.fillPlaceholder(joinSql, DEPARTMENT_PARAM, canViewDepartmentIds);
        }
        return "";
    }

    /**
     * Substitutes every occurrence of {@code placeholder} in {@code joinSql}
     * with the comma-joined ids.
     *
     * @return the filled-in SQL, or an empty string when {@code ids} is empty
     */
    private String fillPlaceholder(String joinSql, String placeholder, List<Long> ids) {
        if (CollectionUtils.isEmpty(ids)) {
            return "";
        }
        // Literal replacement: the placeholder is not a regular expression.
        return joinSql.replace(placeholder, StringUtils.join(ids, ","));
    }
}
/**
 * Resolves, for a role and the current request user, which employee ids or
 * department ids fall inside the configured data scope. Role configuration is
 * loaded from {@code DataScopeRoleDao} and cached in memory.
 */
@Slf4j
@Service
public class DataScopeViewService {

    // roleId -> (DataScopeTypeEnum value -> (DataScopeSqlTypeEnum code -> DataScopeViewTypeEnum))
    private final ConcurrentHashMap<Integer,HashMap<Integer,HashMap<Integer,DataScopeViewTypeEnum>>> dataScopeMap = new ConcurrentHashMap<>();

    @Autowired
    private WorkWxService workWxService;

    @Autowired
    private DataScopeRoleDao dataScopeRoleDao;

    @Autowired
    private DepartmentTreeService departmentTreeService;

    @Autowired
    private UserMapper userMapper;

    /**
     * Returns the ids of all employees the given employee may see.
     *
     * @param dataScopeTypeEnum    module scope of the intercepted method
     * @param dataScopeSqlTypeEnum operation kind (view / update / delete)
     * @param roleId               role of the current user, may be {@code null}
     * @param employeeId           id of the current user
     * @return the visible employee ids; an empty list for any view type other
     *         than ME, DEPARTMENT and DEPARTMENT_AND_SUB
     */
    public List<Long> getCanViewEmployeeId(DataScopeTypeEnum dataScopeTypeEnum,
                                           DataScopeSqlTypeEnum dataScopeSqlTypeEnum,
                                           Integer roleId,
                                           Long employeeId) {
        DataScopeViewTypeEnum viewType = this.getEmployeeDataScopeViewType(dataScopeTypeEnum, dataScopeSqlTypeEnum, roleId);
        if (DataScopeViewTypeEnum.ME == viewType) {
            return this.getMeEmployeeIdList(employeeId);
        }
        if (DataScopeViewTypeEnum.DEPARTMENT == viewType) {
            return this.getDepartmentEmployeeIdList(employeeId);
        }
        if (DataScopeViewTypeEnum.DEPARTMENT_AND_SUB == viewType) {
            return this.getDepartmentAndSubEmployeeIdList(employeeId);
        }
        return Lists.newArrayList();
    }

    /**
     * Returns the ids of all departments the given employee may see.
     *
     * @param dataScopeTypeEnum    module scope of the intercepted method
     * @param dataScopeSqlTypeEnum operation kind (view / update / delete)
     * @param roleId               role of the current user, may be {@code null}
     * @param employeeId           id of the current user
     * @return the visible department ids; an empty list for any view type other
     *         than ME, DEPARTMENT and DEPARTMENT_AND_SUB
     */
    public List<Long> getCanViewDepartmentId(DataScopeTypeEnum dataScopeTypeEnum,
                                             DataScopeSqlTypeEnum dataScopeSqlTypeEnum,
                                             Integer roleId,
                                             Long employeeId) {
        DataScopeViewTypeEnum viewType = this.getEmployeeDataScopeViewType(dataScopeTypeEnum, dataScopeSqlTypeEnum, roleId);
        // ME and DEPARTMENT both resolve to the user's own department.
        if (DataScopeViewTypeEnum.ME == viewType || DataScopeViewTypeEnum.DEPARTMENT == viewType) {
            return this.getMeDepartmentIdList(employeeId);
        }
        if (DataScopeViewTypeEnum.DEPARTMENT_AND_SUB == viewType) {
            return this.getDepartmentAndSubIdList(employeeId);
        }
        return Lists.newArrayList();
    }

    /**
     * Returns a single-element list holding the department id of the
     * thread-local user. The {@code employeeId} argument is not consulted.
     */
    private List<Long> getMeDepartmentIdList(Long employeeId) {
        Integer departmentId = SmartRequestTokenUtil.getThreadLocalUser().getDepartmentId();
        return Collections.singletonList((long)departmentId);
    }

    /**
     * Collects department ids via {@code departmentTreeService.buildIdList},
     * starting from the thread-local user's department. The {@code employeeId}
     * argument is not consulted.
     */
    private List<Long> getDepartmentAndSubIdList(Long employeeId) {
        Integer departmentId = SmartRequestTokenUtil.getThreadLocalUser().getDepartmentId();
        List<Long> allDepartmentIds = Lists.newArrayList();
        departmentTreeService.buildIdList((long) departmentId, allDepartmentIds);
        return allDepartmentIds;
    }

    /**
     * Resolves the view type that applies to a role for the given module scope
     * and operation.
     *
     * @return ME when {@code roleId} is {@code null}, ALL for the admin role,
     *         otherwise the cached configuration (ME when nothing is configured)
     */
    public DataScopeViewTypeEnum getEmployeeDataScopeViewType(DataScopeTypeEnum dataScopeTypeEnum,
                                                              DataScopeSqlTypeEnum dataScopeSqlTypeEnum,
                                                              Integer roleId) {
        // No role assigned: default to "self only".
        if (roleId == null) {
            return DataScopeViewTypeEnum.ME;
        }
        // Administrators see everything.
        if (roleId.equals(RoleEnum.ADMIN.getCode())) {
            return DataScopeViewTypeEnum.ALL;
        }
        // Otherwise use the role's configured data scope.
        return this.getViewType(roleId, dataScopeTypeEnum, dataScopeSqlTypeEnum);
    }

    /**
     * "Self only" scope: just the given employee id.
     *
     * @param employeeId id of the current user
     * @return a list containing only {@code employeeId}
     */
    private List<Long> getMeEmployeeIdList(Long employeeId) {
        return Lists.newArrayList(employeeId);
    }

    /**
     * "Own department" scope: ids of the users that {@code workWxService}
     * reports as being in the same department as the thread-local user.
     *
     * @param employeeId id of the current user, used as the fallback result
     * @return distinct employee ids; only {@code employeeId} when the lookup
     *         fails with a {@code WxErrorException}
     */
    private List<Long> getDepartmentEmployeeIdList(Long employeeId) {
        String wcUserId = SmartRequestTokenUtil.getThreadLocalUser().getWcUserId();
        log.info("wcUserId:{}",wcUserId);
        try {
            List<String> userIdList = workWxService.getSameDepartUserIdByUserId(wcUserId);
            log.info("userIdList:{}",userIdList);
            List<UserSimpleVo> userSimpleVos = userMapper.listInfoByUserId(userIdList);
            return userSimpleVos.stream().map(UserSimpleVo::getId).distinct().collect(Collectors.toList());
        } catch (WxErrorException e) {
            // Best effort: narrow the scope to the user alone rather than fail the request.
            log.error("failed to load same-department users for wcUserId:{}, falling back to self only", wcUserId, e);
            return Lists.newArrayList(employeeId);
        }
    }

    /**
     * "Own department and sub-departments" scope: ids of the users that
     * {@code workWxService} reports for the thread-local user's department and
     * its descendants.
     *
     * @param employeeId id of the current user, used as the fallback result
     * @return distinct employee ids; only {@code employeeId} when the lookup
     *         fails with a {@code WxErrorException}
     */
    private List<Long> getDepartmentAndSubEmployeeIdList(Long employeeId) {
        String wcUserId = SmartRequestTokenUtil.getThreadLocalUser().getWcUserId();
        try {
            List<String> userIdList = workWxService.getSameAndSonDepartUserIdByUserId(wcUserId);
            List<UserSimpleVo> userSimpleVos = userMapper.listInfoByUserId(userIdList);
            return userSimpleVos.stream().map(UserSimpleVo::getId).distinct().collect(Collectors.toList());
        } catch (WxErrorException e) {
            // Best effort: narrow the scope to the user alone rather than fail the request.
            log.error("failed to load department-and-sub users for wcUserId:{}, falling back to self only", wcUserId, e);
            return Lists.newArrayList(employeeId);
        }
    }

    /**
     * Reads the cached view type for a role / module scope / operation.
     *
     * @return the configured view type, or ME when any level of the cache has
     *         no entry (including an entry whose stored value is {@code null})
     */
    private DataScopeViewTypeEnum getViewType(Integer roleId, DataScopeTypeEnum dataScopeTypeEnum,
                                              DataScopeSqlTypeEnum dataScopeSqlTypeEnum) {
        HashMap<Integer, HashMap<Integer, DataScopeViewTypeEnum>> dataScopeTypeMap = dataScopeMap.get(roleId);
        if (Objects.isNull(dataScopeTypeMap)) {
            return DataScopeViewTypeEnum.ME;
        }
        HashMap<Integer, DataScopeViewTypeEnum> sqlTypeMap = dataScopeTypeMap.get(dataScopeTypeEnum.getValue());
        if (Objects.isNull(sqlTypeMap)) {
            return DataScopeViewTypeEnum.ME;
        }
        DataScopeViewTypeEnum dataScopeViewTypeEnum = sqlTypeMap.get(dataScopeSqlTypeEnum.getCode());
        // Check the looked-up value, not the lookup key: a null here would make
        // callers return an empty id list, which means "no restriction" downstream.
        if (Objects.isNull(dataScopeViewTypeEnum)) {
            return DataScopeViewTypeEnum.ME;
        }
        return dataScopeViewTypeEnum;
    }

    /**
     * Loads the configuration of every role into the cache at startup.
     */
    @PostConstruct
    private void buildRoleViewDataMap() {
        this.refreshDataScopeViewTypeMap(null);
        log.info(dataScopeMap.toString());
        log.info("缓存角色数据范围success");
    }

    /**
     * Reloads role configuration from {@code dataScopeRoleDao.listAll(roleId)}
     * and replaces the cache entry of every role present in the result.
     * Nothing is changed when the DAO returns no rows.
     *
     * @param roleId passed through to the DAO; {@code null} is used at startup
     */
    public void refreshDataScopeViewTypeMap(Integer roleId) {
        List<DataScopeRoleEntity> dataScopeRoleEntities = dataScopeRoleDao.listAll(roleId);
        if (CollectionUtils.isEmpty(dataScopeRoleEntities)) {
            return;
        }
        ConcurrentMap<Integer, List<DataScopeRoleEntity>> groupByRoleId =
                dataScopeRoleEntities.stream().collect(Collectors.groupingByConcurrent(DataScopeRoleEntity::getRoleId));
        groupByRoleId.forEach((roleKey, dataScopeRoleEntityList) -> {
            HashMap<Integer, HashMap<Integer, DataScopeViewTypeEnum>> dataScopeTypeMap = new HashMap<>();
            // NOTE(review): toMap without a merge function throws IllegalStateException
            // if one role has two rows with the same dataScopeType - confirm that
            // failing fast is the intended behaviour.
            Map<Integer, DataScopeRoleEntity> groupByDataScopeType =
                    dataScopeRoleEntityList.stream().collect(Collectors.toMap(DataScopeRoleEntity::getDataScopeType,
                            Function.identity()));
            groupByDataScopeType.forEach((dataScopeTypeKey, dataScopeRoleEntity) ->
                    dataScopeTypeMap.put(dataScopeTypeKey, this.buildSqlTypeMap(dataScopeRoleEntity)));
            dataScopeMap.put(roleKey, dataScopeTypeMap);
        });
    }

    /**
     * Maps each operation code (view / update / delete) to the view type
     * configured on the given entity.
     */
    private HashMap<Integer, DataScopeViewTypeEnum> buildSqlTypeMap(DataScopeRoleEntity dataScopeRoleEntity) {
        HashMap<Integer, DataScopeViewTypeEnum> dataScopeViewTypeMap = new HashMap<>();
        dataScopeViewTypeMap.put(DataScopeSqlTypeEnum.VIEW.getCode(),
                SmartBaseEnumUtil.getEnumByValue(dataScopeRoleEntity.getViewType(), DataScopeViewTypeEnum.class));
        dataScopeViewTypeMap.put(DataScopeSqlTypeEnum.UPDATE.getCode(),
                SmartBaseEnumUtil.getEnumByValue(dataScopeRoleEntity.getUpdateType(), DataScopeViewTypeEnum.class));
        dataScopeViewTypeMap.put(DataScopeSqlTypeEnum.DELETE.getCode(),
                SmartBaseEnumUtil.getEnumByValue(dataScopeRoleEntity.getDeleteType(), DataScopeViewTypeEnum.class));
        return dataScopeViewTypeMap;
    }
}

sql拦截器

@Slf4j
public class SqlInterceptor implements InnerInterceptor {

    @SneakyThrows
    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        DataScopeSqlConfigDTO sqlConfigDTO = this.getSqlConfigDTO(ms);
        if(sqlConfigDTO != null) {
            //原始sql
            String originalSql = boundSql.getSql().trim();
            //新sql
            String newSql = this.joinSql(originalSql, sqlConfigDTO);
            //通过反射修改sql语句
            Field field = boundSql.getClass().getDeclaredField("sql");
            field.setAccessible(true);
            field.set(boundSql,newSql);
        }
    }

    private DataScopeSqlConfigDTO getSqlConfigDTO(MappedStatement ms) {
        //获取执行方法的位置
        String namespace = ms.getId();
        //获取配置配置
        String[] split = namespace.split("\.");
        int length = split.length;
        String pathKey = split[length - 2] + "." + split[length - 1];
        DataScopeSqlConfigService dataScopeSqlConfigService = this.dataScopeSqlConfigService();
        if(dataScopeSqlConfigService == null) {
            return null;
        }
        return dataScopeSqlConfigService.getSqlConfig(pathKey);
    }

    @SneakyThrows
    @Override
    public void beforeUpdate(Executor executor, MappedStatement ms, Object parameter) throws SQLException {
        DataScopeSqlConfigDTO sqlConfigDTO = this.getSqlConfigDTO(ms);
        if(sqlConfigDTO != null) {
            BoundSql boundSql = ms.getBoundSql(parameter);
            String originalSql = boundSql.getSql();
            String newSql = this.joinSql(originalSql, sqlConfigDTO);
            //通过反射修改sql语句
            Field field = boundSql.getClass().getDeclaredField("sql");
            field.setAccessible(true);
            field.set(boundSql,newSql);
        }
    }

    public DataScopeSqlConfigService dataScopeSqlConfigService() {
        return (DataScopeSqlConfigService) SmartApplicationContext.getBean("dataScopeSqlConfigService");
    }

    private String joinSql(String sql, DataScopeSqlConfigDTO sqlConfigDTO) {
        if (null == sqlConfigDTO) {
            return sql;
        }
        String appendSql = this.dataScopeSqlConfigService().getJoinSql(sqlConfigDTO);
        if (StringUtils.isEmpty(appendSql)) {
            return sql;
        }
        Integer appendSqlWhereIndex = sqlConfigDTO.getWhereIndex();
        String where = "where";
        String order = "order by";
        String group = "group by";
        int whereIndex = StringUtils.ordinalIndexOf(sql.toLowerCase(), where, appendSqlWhereIndex + 1);
        int orderIndex = sql.toLowerCase().indexOf(order);
        int groupIndex = sql.toLowerCase().indexOf(group);
        if (whereIndex > - 1) {
            String subSql = sql.substring(0, whereIndex + where.length() + 1);
            subSql = subSql + " " + appendSql + " AND " + sql.substring(whereIndex + where.length() + 1);
            return subSql;
        }

        if (groupIndex > - 1) {
            String subSql = sql.substring(0, groupIndex);
            subSql = subSql + " where " + appendSql + " " + sql.substring(groupIndex);
            return subSql;
        }
        if (orderIndex > - 1) {
            String subSql = sql.substring(0, orderIndex);
            subSql = subSql + " where " + appendSql + " " + sql.substring(orderIndex);
            return subSql;
        }
        sql += " where " + appendSql;
        return sql;
    }
}