MyBatis-Plus 多租户权限配置

821 次阅读 · 约 2 分钟
package com.small.account.config;

import com.baomidou.mybatisplus.core.config.GlobalConfig;
import com.baomidou.mybatisplus.core.parser.ISqlParser;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.extension.plugins.PaginationInterceptor;
import com.baomidou.mybatisplus.extension.plugins.tenant.TenantHandler;
import com.baomidou.mybatisplus.extension.plugins.tenant.TenantSqlParser;
import com.common.basis.entity.User;
import com.common.basis.util.StringMappingTool;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.NullValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.ValueListExpression;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.TokenStore;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * 分页配置
 *
 * @author wangguochao
 */
@Configuration
@Slf4j
@EnableConfigurationProperties(TenantProperties.class)
@AllArgsConstructor
public class MyBatisPlusConfig {

    private final TenantProperties tenantProperties;
    private final TokenStore tokenStore;

    /**
     * 自动填充功能
     */
    @Bean
    public GlobalConfig globalConfig() {
        GlobalConfig globalConfig = new GlobalConfig();
        globalConfig.setMetaObjectHandler(new MetaHandler());
        return globalConfig;
    }

    /**
     * 分页
     */
    @Bean
    public PaginationInterceptor paginationInterceptor() {
        PaginationInterceptor paginationInterceptor = new PaginationInterceptor();
        // 创建SQL解析器集合
        List<ISqlParser> sqlParserList = new ArrayList<>();

        // 创建租户SQL解析器
        TenantSqlParser tenantSqlParser = new TenantSqlParser();
        // 设置租户处理器
        tenantSqlParser.setTenantHandler(new TenantHandler() {
            /**
             * 获取租户 ID 值表达式,只支持单个 ID 值
             * <p>
             *
             * @return 租户 ID 值表达式
             */
            @Override
            public Expression getTenantId(boolean select) {
                String account = "";
                //当前登录人租户
                if (SecurityContextHolder.getContext().getAuthentication() != null) {
                    OAuth2Authentication authentication = (OAuth2Authentication) SecurityContextHolder.getContext().getAuthentication();
                    Authentication userAuthentication = authentication.getUserAuthentication();
                    User user = (User) userAuthentication.getPrincipal();
                    account = user.getTenantId();
                }
                //是否携带参数
                HttpServletRequest request = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
                String tenantId = request.getHeader("tenantId");
                if (account.equals(StringMappingTool.SYSTEM) && StringUtils.isNotBlank(tenantId)) {
                    account = tenantId;
                }
                // 设置当前租户ID,实际情况你可以从cookie、或者缓存中拿都行
                log.info("当前租户为{}", Objects.requireNonNull(account));
                return new StringValue(account);
            }

            /**
             * 获取租户字段名
             * <p>
             * 默认字段名叫: tenant_id
             *
             * @return 租户字段名
             */
            @Override
            public String getTenantIdColumn() {
                // 对应数据库租户ID的列名
                return "tenant_id";
            }

            /**
             * 做表过滤器
             *
             * @param tableName 表名
             * @return boolean
             */
            @Override
            public boolean doTableFilter(String tableName) {
                HttpServletRequest request = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
                String tenantId = request.getHeader("tenantId");
                if (tenantProperties.getSysTables().contains(tableName)) {
                    return true;
                }
                String account = null;
                if (SecurityContextHolder.getContext().getAuthentication() != null) {
                    OAuth2Authentication authentication = (OAuth2Authentication) SecurityContextHolder.getContext().getAuthentication();
                    Authentication userAuthentication = authentication.getUserAuthentication();
                    User user = (User) userAuthentication.getPrincipal();
                    account = user.getTenantId();
                }
                return tenantProperties.getAdmin().contains(account) && StringUtils.isBlank(tenantId);
            }
        });
        // 创建部门SQL解析器
        TenantSqlParser mechanismSqlParser = new TenantSqlParser();
        // 设置部门处理器
        mechanismSqlParser.setTenantHandler(new TenantHandler() {
            /**
             * 获取部门 ID 值表达式,ID,IDs 值
             * <p>
             *
             * @return 部门 ID 值表达式
             */
            @Override
            public Expression getTenantId(boolean select) {
                List<String> mechanism = new ArrayList<>();
                //当前登录人租户
                if (SecurityContextHolder.getContext().getAuthentication() != null) {
                    OAuth2Authentication authentication = (OAuth2Authentication) SecurityContextHolder.getContext().getAuthentication();
                    Authentication userAuthentication = authentication.getUserAuthentication();
                    User user = (User) userAuthentication.getPrincipal();
                    mechanism = user.getRoleAndMechanismIds();
                }
                //是否携带参数
                HttpServletRequest request = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
                String mechanismIds = request.getHeader("mechanismIds");
                if (StringUtils.isNotBlank(mechanismIds)) {
                    mechanism = Arrays.asList(mechanismIds.split(","));
                }
                // 设置当前租户ID,实际情况你可以从cookie、或者缓存中拿都行
                //****重点是这里****
                log.info("当前机构为{}", Objects.requireNonNull(mechanism));
                ValueListExpression inExpression = new ValueListExpression();
                ExpressionList expressionList = new ExpressionList();
                List<Expression> values = mechanism.stream().map(v -> (Expression) new StringValue(v)).collect(Collectors.toList());
                expressionList.setExpressions(values);
                inExpression.setExpressionList(expressionList);
                return inExpression;
            }

            /**
             * 获取机构字段名
             * <p>
             * 默认字段名叫: au_mechanism_id
             *
             * @return 机构字段名
             */
            @Override
            public String getTenantIdColumn() {
                // 对应数据库租户ID的列名
                return "au_mechanism_id";
            }

            /**
             * 做表过滤器
             *
             * @param tableName 表名
             * @return boolean
             */
            @Override
            public boolean doTableFilter(String tableName) {
                HttpServletRequest request = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
                String tenantId = request.getHeader("tenantId");
                if (!tenantProperties.getTables().contains(tableName)) {
                    return true;
                }
                String account = null;
                if (SecurityContextHolder.getContext().getAuthentication() != null) {
                    OAuth2Authentication authentication = (OAuth2Authentication) SecurityContextHolder.getContext().getAuthentication();
                    Authentication userAuthentication = authentication.getUserAuthentication();
                    User user = (User) userAuthentication.getPrincipal();
                    account = user.getTenantId();
                }
                return tenantProperties.getAdmin().contains(account) && StringUtils.isBlank(tenantId);
            }
        });
        sqlParserList.add(tenantSqlParser);
        sqlParserList.add(mechanismSqlParser);
        paginationInterceptor.setSqlParserList(sqlParserList);
        // 设置请求的页面大于最大页后操作, true调回到首页,false 继续请求  默认false
        paginationInterceptor.setOverflow(false);
        // 设置最大单页限制数量,默认 500 条,-1 不受限制
        paginationInterceptor.setLimit(20);
        return paginationInterceptor;
    }
}