package com.small.account.config;
import com.baomidou.mybatisplus.core.config.GlobalConfig;
import com.baomidou.mybatisplus.core.parser.ISqlParser;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.extension.plugins.PaginationInterceptor;
import com.baomidou.mybatisplus.extension.plugins.tenant.TenantHandler;
import com.baomidou.mybatisplus.extension.plugins.tenant.TenantSqlParser;
import com.common.basis.entity.User;
import com.common.basis.util.StringMappingTool;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.NullValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.ValueListExpression;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.TokenStore;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@Configuration
@Slf4j
@EnableConfigurationProperties(TenantProperties.class)
@AllArgsConstructor
public class MyBatisPlusConfig {
private final TenantProperties tenantProperties;
private final TokenStore tokenStore;
@Bean
public GlobalConfig globalConfig() {
GlobalConfig globalConfig = new GlobalConfig();
globalConfig.setMetaObjectHandler(new MetaHandler());
return globalConfig;
}
@Bean
public PaginationInterceptor paginationInterceptor() {
PaginationInterceptor paginationInterceptor = new PaginationInterceptor();
List<ISqlParser> sqlParserList = new ArrayList<>();
TenantSqlParser tenantSqlParser = new TenantSqlParser();
tenantSqlParser.setTenantHandler(new TenantHandler() {
@Override
public Expression getTenantId(boolean select) {
String account = "";
if (SecurityContextHolder.getContext().getAuthentication() != null) {
OAuth2Authentication authentication = (OAuth2Authentication) SecurityContextHolder.getContext().getAuthentication();
Authentication userAuthentication = authentication.getUserAuthentication();
User user = (User) userAuthentication.getPrincipal();
account = user.getTenantId();
}
HttpServletRequest request = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
String tenantId = request.getHeader("tenantId");
if (account.equals(StringMappingTool.SYSTEM) && StringUtils.isNotBlank(tenantId)) {
account = tenantId;
}
log.info("当前租户为{}", Objects.requireNonNull(account));
return new StringValue(account);
}
@Override
public String getTenantIdColumn() {
return "tenant_id";
}
@Override
public boolean doTableFilter(String tableName) {
HttpServletRequest request = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
String tenantId = request.getHeader("tenantId");
if (tenantProperties.getSysTables().contains(tableName)) {
return true;
}
String account = null;
if (SecurityContextHolder.getContext().getAuthentication() != null) {
OAuth2Authentication authentication = (OAuth2Authentication) SecurityContextHolder.getContext().getAuthentication();
Authentication userAuthentication = authentication.getUserAuthentication();
User user = (User) userAuthentication.getPrincipal();
account = user.getTenantId();
}
return tenantProperties.getAdmin().contains(account) && StringUtils.isBlank(tenantId);
}
});
TenantSqlParser mechanismSqlParser = new TenantSqlParser();
mechanismSqlParser.setTenantHandler(new TenantHandler() {
@Override
public Expression getTenantId(boolean select) {
List<String> mechanism = new ArrayList<>();
if (SecurityContextHolder.getContext().getAuthentication() != null) {
OAuth2Authentication authentication = (OAuth2Authentication) SecurityContextHolder.getContext().getAuthentication();
Authentication userAuthentication = authentication.getUserAuthentication();
User user = (User) userAuthentication.getPrincipal();
mechanism = user.getRoleAndMechanismIds();
}
HttpServletRequest request = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
String mechanismIds = request.getHeader("mechanismIds");
if (StringUtils.isNotBlank(mechanismIds)) {
mechanism = Arrays.asList(mechanismIds.split(","));
}
log.info("当前机构为{}", Objects.requireNonNull(mechanism));
ValueListExpression inExpression = new ValueListExpression();
ExpressionList expressionList = new ExpressionList();
List<Expression> values = mechanism.stream().map(v -> (Expression) new StringValue(v)).collect(Collectors.toList());
expressionList.setExpressions(values);
inExpression.setExpressionList(expressionList);
return inExpression;
}
@Override
public String getTenantIdColumn() {
return "au_mechanism_id";
}
@Override
public boolean doTableFilter(String tableName) {
HttpServletRequest request = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
String tenantId = request.getHeader("tenantId");
if (!tenantProperties.getTables().contains(tableName)) {
return true;
}
String account = null;
if (SecurityContextHolder.getContext().getAuthentication() != null) {
OAuth2Authentication authentication = (OAuth2Authentication) SecurityContextHolder.getContext().getAuthentication();
Authentication userAuthentication = authentication.getUserAuthentication();
User user = (User) userAuthentication.getPrincipal();
account = user.getTenantId();
}
return tenantProperties.getAdmin().contains(account) && StringUtils.isBlank(tenantId);
}
});
sqlParserList.add(tenantSqlParser);
sqlParserList.add(mechanismSqlParser);
paginationInterceptor.setSqlParserList(sqlParserList);
paginationInterceptor.setOverflow(false);
paginationInterceptor.setLimit(20);
return paginationInterceptor;
}
}