What Happens to a Query Statement in ShardingJDBC


As our business data grew, the company's original database- and table-sharding scheme proved too inflexible, so we decided to adopt a mature, widely used sharding component. After weighing integration cost, performance overhead, and community activity, we chose ShardingJDBC. However, ShardingJDBC does not yet support some of our existing business scenarios well, so we plan to extend it through its extension points or modify the source directly, which requires reading the ShardingJDBC source code in depth.

ShardingJDBC Architecture

The architecture diagram above is copied from the official site. It clearly shows the rough flow of a SQL statement:

  1. The parse engine parses the SQL into an abstract syntax tree (AST) according to the database dialect.
  2. The route engine generates a route context from the routing rules and the AST.
  3. The rewrite engine rewrites the SQL based on the route context and the AST, producing a rewrite context.
  4. An execution context is generated from the route context and the rewrite context.
  5. The execute engine groups the execution context by data source and connection count, then executes the SQL concurrently.
  6. The merge engine merges the execution results according to the routing rules.

That is the rough execution flow of a SQL statement. Next, let's walk through the concrete flow at the source level. Note: this article is based on the 5.5.0 source code.
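Before diving in, here is a minimal usage sketch showing where all of this kicks in from the application's point of view. The YAML file name and table/column names are assumptions for illustration; YamlShardingSphereDataSourceFactory is the standard ShardingSphere-JDBC entry point:

// A minimal usage sketch (config file path and table names are assumed).
// Everything discussed below happens inside the driver: parse -> route -> rewrite -> execute -> merge.
import org.apache.shardingsphere.driver.api.yaml.YamlShardingSphereDataSourceFactory;

import javax.sql.DataSource;
import java.io.File;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;

public final class ShardingJdbcQuickStart {
    
    public static void main(final String[] args) throws Exception {
        // Build a logical DataSource from a sharding rule configuration (assumed path).
        DataSource dataSource = YamlShardingSphereDataSourceFactory.createDataSource(new File("sharding.yaml"));
        try (Connection connection = dataSource.getConnection();
             // SQL parsing happens here, during prepareStatement (see below).
             PreparedStatement statement = connection.prepareStatement("SELECT * FROM t_user WHERE f_id = ?")) {
            statement.setLong(1, 1L);
            // Routing, rewriting, execution, and merging all happen inside executeQuery.
            try (ResultSet resultSet = statement.executeQuery()) {
                while (resultSet.next()) {
                    System.out.println(resultSet.getString("f_name"));
                }
            }
        }
    }
}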

Parse Engine

First, let's see at which stage SQL parsing happens, then look at the parsing logic itself. Parsing is triggered here:

// Create a prepared statement
PreparedStatement preparedStatement = connection.prepareStatement(sql);

// org.apache.shardingsphere.driver.jdbc.core.connection.ShardingSphereConnection#prepareStatement
public PreparedStatement prepareStatement(final String sql) throws SQLException {
    return new ShardingSpherePreparedStatement(this, sql);
}

// org.apache.shardingsphere.driver.jdbc.core.statement.ShardingSpherePreparedStatement#ShardingSpherePreparedStatement
private ShardingSpherePreparedStatement(final ShardingSphereConnection connection, final String sql,
                                          final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability, final boolean returnGeneratedKeys,
                                          final String[] columns) throws SQLException {
    // omitted...
    SQLParserRule sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(SQLParserRule.class);
    // Obtain the SQL parser engine
    SQLParserEngine sqlParserEngine = sqlParserRule.getSQLParserEngine(metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getProtocolType());
    // The parse cache can be configured via SQLParserRuleConfiguration
    sqlStatement = sqlParserEngine.parse(this.sql, true);
    // Bind the parsed statement against the metadata to produce the statement context
    sqlStatementContext = new SQLBindEngine(metaDataContexts.getMetaData(), connection.getDatabaseName(), hintValueContext).bind(sqlStatement, Collections.emptyList());
    // omitted...
}

As shown above, SQL parsing is performed while the prepared statement is created. Besides parsing, this phase also initializes the kernel processor, the various executors, and so on.

Next, let's see how SQL parsing actually works. Parsing is built on Antlr4 (if you are not familiar with it, an introductory article is worth reading first):

// org.apache.shardingsphere.infra.parser.sql.SQLStatementParserEngine#parse
public SQLStatement parse(final String sql, final boolean useCache) {
    // If the SQLStatement for this SQL is already cached, return it directly; the cache is backed by Caffeine
    return useCache ? sqlStatementCache.get(sql) : sqlStatementParserExecutor.parse(sql);
}

// org.apache.shardingsphere.infra.parser.sql.SQLStatementParserExecutor#parse
public SQLStatement parse(final String sql) {
    // 1. Build the abstract syntax tree from the Antlr4 grammar rules
    // 2. Generate the SQLStatement with an Antlr4 visitor
    return visitorEngine.visit(parserEngine.parse(sql, false));
}

// org.apache.shardingsphere.sql.parser.api.SQLParserEngine#parse
public ParseASTNode parse(final String sql, final boolean useCache) {
    // Prefer the cache: return the cached AST if present, otherwise parse
    return useCache ? parseTreeCache.get(sql) : sqlParserExecutor.parse(sql);
}

// org.apache.shardingsphere.sql.parser.core.database.parser.SQLParserExecutor#parse
public ParseASTNode parse(final String sql) {
    ParseASTNode result = twoPhaseParse(sql);
    //...
    return result;
}

private ParseASTNode twoPhaseParse(final String sql) {
    // Load the database dialect parser facade via SPI; it provides the lexer and parser classes
    DialectSQLParserFacade sqlParserFacade = DatabaseTypedSPILoader.getService(DialectSQLParserFacade.class, databaseType);
    // Instantiate the parser, which extends the code generated by Antlr4
    SQLParser sqlParser = SQLParserFactory.newInstance(sql, sqlParserFacade.getLexerClass(), sqlParserFacade.getParserClass());
    ((Parser) sqlParser).getInterpreter().setPredictionMode(PredictionMode.SLL);
    return (ParseASTNode) sqlParser.parse();
}

That is the concrete parsing logic. It is built on Antlr4 and consists of two stages:

  1. Parse the SQL with the Antlr4-generated parser to build the abstract syntax tree.
  2. Traverse the AST with a custom Visitor to produce the corresponding SQLStatement.
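Incidentally, the elided part of twoPhaseParse (the //... above) is where the second phase of its name lives: if fast SLL prediction fails, the parse is retried with full LL prediction. This is a standard Antlr4 pattern; the following is a hedged sketch of it, not the exact ShardingSphere code, with illustrative names:

// Sketch: SLL-first parsing with an LL fallback (illustrative, not the exact ShardingSphere code).
import java.util.function.Supplier;

import org.antlr.v4.runtime.BailErrorStrategy;
import org.antlr.v4.runtime.DefaultErrorStrategy;
import org.antlr.v4.runtime.Parser;
import org.antlr.v4.runtime.TokenStream;
import org.antlr.v4.runtime.atn.PredictionMode;
import org.antlr.v4.runtime.misc.ParseCancellationException;
import org.antlr.v4.runtime.tree.ParseTree;

public final class TwoPhaseParseSketch {
    
    public static ParseTree parseWithFallback(final Parser parser, final TokenStream tokens, final Supplier<ParseTree> entryRule) {
        // Phase 1: SLL prediction with a bail-out error strategy covers the vast majority of statements.
        parser.getInterpreter().setPredictionMode(PredictionMode.SLL);
        parser.setErrorHandler(new BailErrorStrategy());
        try {
            return entryRule.get();
        } catch (final ParseCancellationException ex) {
            // Phase 2: rewind the tokens and re-parse with full LL prediction and normal error recovery.
            tokens.seek(0);
            parser.reset();
            parser.getInterpreter().setPredictionMode(PredictionMode.LL);
            parser.setErrorHandler(new DefaultErrorStrategy());
            return entryRule.get();
        }
    }
}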

Now let's briefly look at how ShardingJDBC uses Antlr4; the code lives in the shardingsphere-parser module.

Take the grammar rules for the MySQL INSERT statement as an example:

grammar DMLStatement;

import BaseRule;

insert
    : INSERT insertSpecification INTO? tableName partitionNames? (insertValuesClause | setAssignmentsClause | insertSelectClause) onDuplicateKeyClause?
    ;
    
insertSpecification
    : (LOW_PRIORITY | DELAYED | HIGH_PRIORITY)? IGNORE?
    ;

insertValuesClause
    : (LP_ fields? RP_ )? (VALUES | VALUE) (assignmentValues (COMMA_ assignmentValues)* | rowConstructorList) valueReference?
    ;

// org.apache.shardingsphere.sql.parser.mysql.parser.MySQLParser#parse
public ASTNode parse() {
    // execute() is generated from the Antlr4 grammar rules; the token stream comes from the lexer
    return new ParseASTNode(execute(), (CommonTokenStream) getTokenStream());
}

That covers AST construction: the grammar files are fed to Antlr4, which generates the parsing code:

  • execute is the entry rule for MySQL parsing, defined in MySQLStatement.g4.
  • getTokenStream returns the TokenStream produced by the lexer.
  • parse simply wraps the Antlr4 AST and the TokenStream together for the subsequent Visitor traversal; the sketch below shows this wiring end to end.
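To make the pipeline concrete, here is a hedged sketch of driving an Antlr4-generated lexer and parser by hand. The MySQLLexer/MySQLParser class names stand in for whatever Antlr4 generates from the grammar files (assumptions for illustration); CharStreams and CommonTokenStream are standard Antlr4 runtime classes:

// Sketch: wiring an Antlr4-generated lexer and parser (generated class names assumed).
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.tree.ParseTree;

public final class ParsePipelineSketch {
    
    public static ParseTree buildAst(final String sql) {
        // 1. The lexer turns the raw SQL characters into a token stream.
        MySQLLexer lexer = new MySQLLexer(CharStreams.fromString(sql));
        CommonTokenStream tokenStream = new CommonTokenStream(lexer);
        // 2. The parser consumes the token stream according to the grammar rules.
        MySQLParser parser = new MySQLParser(tokenStream);
        // 3. Invoking the entry rule builds the parse tree (the AST the visitor traverses later).
        return parser.execute();
    }
}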

Next, let's look at how the SQLStatement is generated:

// org.apache.shardingsphere.sql.parser.api.SQLStatementVisitorEngine#visit
public SQLStatement visit(final ParseASTNode parseASTNode) {
    // Instantiate the visitor, which extends the Antlr4-generated code
    SQLStatementVisitor visitor = SQLStatementVisitorFactory.newInstance(databaseType, SQLVisitorRule.valueOf(parseASTNode.getRootNode().getClass()));
    // Traverse the syntax tree to build the SQLStatement
    ASTNode result = parseASTNode.getRootNode().accept(visitor);
    appendSQLComments(parseASTNode, result);
    return (SQLStatement) result;
}

// org.apache.shardingsphere.sql.parser.mysql.visitor.statement.MySQLStatementVisitor#visitInsert
public ASTNode visitInsert(final InsertContext ctx) {
    // TODO :FIXME, since there is no segment for insertValuesClause, InsertStatement is created by sub rule.
    MySQLInsertStatement result;
    
    if (null != ctx.insertValuesClause()) {
        // Build the statement from the INSERT ... VALUES clause
        result = (MySQLInsertStatement) visit(ctx.insertValuesClause());
    } else if (null != ctx.insertSelectClause()) {
        // Build the statement from the INSERT ... SELECT clause
        result = (MySQLInsertStatement) visit(ctx.insertSelectClause());
    } else {
        // INSERT ... SET assignments
        result = new MySQLInsertStatement();
        result.setSetAssignment((SetAssignmentSegment) visit(ctx.setAssignmentsClause()));
    }
    // Handle ON DUPLICATE KEY
    if (null != ctx.onDuplicateKeyClause()) {
        result.setOnDuplicateKeyColumns((OnDuplicateKeyColumnsSegment) visit(ctx.onDuplicateKeyClause()));
    }
    // Set the table name
    result.setTable((SimpleTableSegment) visit(ctx.tableName()));
    // Record the parameter-marker (?) segments collected during parsing
    result.addParameterMarkerSegments(getParameterMarkerSegments());
    return result;
}

The visitor's job is to customize the traversal of the abstract syntax tree. Antlr4 generates one visit method per grammar rule; for intuition, you can think of it as one visit method per node type. A visit method receives a tree node as its argument, and when given a non-leaf node it traverses that node's subtree.

A quick analysis of the visitInsert method:

  • visitInsert is the entry point for the INSERT subtree. Calling parseASTNode.getRootNode().accept(visitor) starts from visitExecute and eventually reaches visitInsert.
  • If the insertValuesClause subtree is non-null, traverse it.
  • If the insertSelectClause subtree is non-null, traverse it.
  • The remaining steps are similar: the statements produced by traversing the subtrees are assembled into the final statement and returned, as sketched below.
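For intuition, a trimmed-down custom visitor has the same shape. The generated base-class and context names here are assumptions, and the point is only that each overridden visitXxx assembles its result from the results of visiting its sub-rules:

// Sketch: a custom visitor over an Antlr4-generated base visitor (class names assumed).
public final class InsertVisitorSketch extends DMLStatementBaseVisitor<Object> {
    
    @Override
    public Object visitInsert(final DMLStatementParser.InsertContext ctx) {
        // Each sub-rule context is visited separately; the parent assembles the pieces.
        Object valuesClause = null == ctx.insertValuesClause() ? null : visit(ctx.insertValuesClause());
        Object table = visit(ctx.tableName());
        // In ShardingSphere the assembled result is a SQLStatement; a plain holder suffices here.
        return new Object[] {table, valuesClause};
    }
}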

Executing the Query

Now for the query execution logic:

@Override
public ResultSet executeQuery() throws SQLException {
    ResultSet result;
    try {
        // Create the query context
        QueryContext queryContext = createQueryContext();
        // Handle auto-commit / begin the transaction if necessary
        handleAutoCommit(queryContext);
        // Traffic governance: when a traffic rule matches, the statement is forwarded to a governed instance
        trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null);
        if (null != trafficInstanceId) {
            JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext);
            return executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).executeQuery());
        }
        // Federation (experimental): decide whether federated query execution is needed
        useFederation = decide(queryContext,
                metaDataContexts.getMetaData().getDatabase(databaseName), metaDataContexts.getMetaData().getGlobalRuleMetaData());
        if (useFederation) {
            return executeFederationQuery(queryContext);
        }
        // Create the execution context
        executionContext = createExecutionContext(queryContext);
        // Execute the query
        result = doExecuteQuery(executionContext);
        // CHECKSTYLE:OFF
    } catch (final RuntimeException ex) {
        // CHECKSTYLE:ON
        handleExceptionInTransaction(connection, metaDataContexts);
        throw SQLExceptionTransformEngine.toSQLException(ex, metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType());
    } 
    return result;
}

// org.apache.shardingsphere.driver.jdbc.core.statement.ShardingSpherePreparedStatement#createExecutionContext
private ExecutionContext createExecutionContext(final QueryContext queryContext) {
    RuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
    ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(databaseName);
    // SQL audit check
    SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
    // ☆☆☆☆ Core ☆☆☆☆ generate the execution context
    ExecutionContext result = kernelProcessor.generateExecutionContext(
            queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
    findGeneratedKey(result).ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
    return result;
}

// org.apache.shardingsphere.driver.jdbc.core.statement.ShardingSpherePreparedStatement#doExecuteQuery
private ShardingSphereResultSet doExecuteQuery(final ExecutionContext executionContext) throws SQLException {
    // Execute the query
    List<QueryResult> queryResults = executeQuery0(executionContext);
    // ☆☆☆☆ Core ☆☆☆☆ merge results, encryption, masking
    MergedResult mergedResult = mergeQuery(queryResults, executionContext.getSqlStatementContext());
    List<ResultSet> resultSets = getResultSets();
    if (null == columnLabelAndIndexMap) {
        columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData());
    }
    return new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContext, columnLabelAndIndexMap);
}

// org.apache.shardingsphere.infra.connection.kernel.KernelProcessor#generateExecutionContext
public ExecutionContext generateExecutionContext(final QueryContext queryContext, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData,
                                                     final ConfigurationProperties props, final ConnectionContext connectionContext) {
    // Route (sharding, read/write splitting, data source matching, ...): derive the route context
    // from the query context, the database, and the global rule metadata
    RouteContext routeContext = route(queryContext, database, globalRuleMetaData, props, connectionContext);
    // Rewrite the SQL according to the route context, producing the rewrite result
    SQLRewriteResult rewriteResult = rewrite(queryContext, database, globalRuleMetaData, props, routeContext, connectionContext);
    // Create the execution context from the query context, database, route context, and rewrite result
    ExecutionContext result = createExecutionContext(queryContext, database, routeContext, rewriteResult);
    logSQL(queryContext, props, result);
    return result;
}

Query execution consists mainly of creating the execution context and then actually running the query. Federated query execution is an experimental feature, so we will skip it for now.

  • Creating the execution context covers SQL auditing, route calculation (sharding, read/write splitting, shadow database), and SQL rewriting (sharding rewrite, data encryption).
  • Actual execution covers grouping the execution context, executing the groups concurrently, and merging the query results (shard merging, pagination merging, data masking, data decryption).

SQL Audit Engine

@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class SQLAuditEngine {
    
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static void audit(final SQLStatementContext sqlStatementContext, final List<Object> params,
                             final RuleMetaData globalRuleMetaData, final ShardingSphereDatabase database, final Grantee grantee, final HintValueContext hintValueContext) {
        Collection<ShardingSphereRule> rules = new LinkedList<>(globalRuleMetaData.getRules());
        if (null != database) {
            rules.addAll(database.getRuleMetaData().getRules());
        }
        // Load the auditors matching the configured rules via SPI
        for (Entry<ShardingSphereRule, SQLAuditor> entry : OrderedSPILoader.getServices(SQLAuditor.class, rules).entrySet()) {
            entry.getValue().audit(sqlStatementContext, params, grantee, globalRuleMetaData, database, entry.getKey(), hintValueContext);
        }
    }
}

The audit logic is fairly simple: auditors matching the configured rules are loaded via SPI and invoked in turn. For example, the default sharding auditor rejects SQL statements that carry no sharding conditions.
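Conceptually, that default sharding auditor boils down to a check like the following. This is a hedged sketch, not the actual class; the parameter names are made up for illustration:

// Sketch of what a sharding auditor conceptually does (illustrative, not the real SPI class).
public final class ShardingConditionAuditorSketch {
    
    public static void audit(final boolean sqlHasShardingConditions, final boolean auditDisabledByHint) {
        // A hint may be allowed to bypass the audit (configurable on the real audit strategy).
        if (auditDisabledByHint) {
            return;
        }
        // Reject DML that would fan out to every shard because no sharding condition was given.
        if (!sqlHasShardingConditions) {
            throw new IllegalStateException("Not allow DML operation without sharding conditions");
        }
    }
}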

Route Engine

// org.apache.shardingsphere.infra.route.engine.impl.PartialSQLRouteExecutor#route
public RouteContext route(final ConnectionContext connectionContext, final QueryContext queryContext, final RuleMetaData globalRuleMetaData, final ShardingSphereDatabase database) {
    RouteContext result = new RouteContext();
    // Forced data source routing via hint
    Optional<String> dataSourceName = findDataSourceByHint(queryContext.getHintValueContext(), database.getResourceMetaData().getStorageUnits());
    if (dataSourceName.isPresent()) {
        result.getRouteUnits().add(new RouteUnit(new RouteMapper(dataSourceName.get(), dataSourceName.get()), Collections.emptyList()));
        return result;
    }
    // Decorate the route context; the routers here depend on the configured rules, e.g. a shardingRule or a readwriteSplittingRule each contributes a router
    for (Entry<ShardingSphereRule, SQLRouter> entry : routers.entrySet()) {
        if (result.getRouteUnits().isEmpty()) {
            result = entry.getValue().createRouteContext(queryContext, globalRuleMetaData, database, entry.getKey(), props, connectionContext);
        } else {
            entry.getValue().decorateRouteContext(result, queryContext, database, entry.getKey(), props, connectionContext);
        }
    }
    if (result.getRouteUnits().isEmpty() && 1 == database.getResourceMetaData().getStorageUnits().size()) {
        String singleDataSourceName = database.getResourceMetaData().getStorageUnits().keySet().iterator().next();
        result.getRouteUnits().add(new RouteUnit(new RouteMapper(singleDataSourceName, singleDataSourceName), Collections.emptyList()));
    }
    return result;
}

A quick walk through the code above:

  • If a data source is forced via hint, return immediately.
  • SQLRouter instances are loaded via SPI according to the configured rules; they cover single tables, sharding, broadcast, read/write splitting, shadow databases, and more.
  • SQLRouter follows a decorator pattern: each router declares an order; the highest-priority router creates the route context, and the subsequent routers decorate it.

Next, let's look at sharding and read/write splitting; the other routers are left to the interested reader. Sharding first:

// org.apache.shardingsphere.sharding.route.engine.ShardingSQLRouter#createRouteContext0
private RouteContext createRouteContext0(final QueryContext queryContext, final RuleMetaData globalRuleMetaData, final ShardingSphereDatabase database, final ShardingRule rule,
                                             final ConfigurationProperties props, final ConnectionContext connectionContext) {
    // Get the SQL statement from the query context
    SQLStatement sqlStatement = queryContext.getSqlStatementContext().getSqlStatement();
    // Parse the sharding conditions
    ShardingConditions shardingConditions = createShardingConditions(queryContext, globalRuleMetaData, database, rule);
    // Validate the SQL statement
    Optional<ShardingStatementValidator> validator = ShardingStatementValidatorFactory.newInstance(sqlStatement, shardingConditions, globalRuleMetaData);
    validator.ifPresent(optional -> optional.preValidate(rule, queryContext.getSqlStatementContext(), queryContext.getParameters(), database, props));
    // Statements containing subqueries need their sharding conditions merged
    if (sqlStatement instanceof DMLStatement && shardingConditions.isNeedMerge()) {
        // Merge the sharding conditions
        shardingConditions.merge();
    }
    // Pick the sharding route engine based on the statement type and routing rules, then route
    RouteContext result = ShardingRouteEngineFactory.newInstance(rule, database, queryContext, shardingConditions, props, connectionContext, globalRuleMetaData)
            .route(rule);
    // Post-route validation
    validator.ifPresent(optional -> optional.postValidate(rule, queryContext.getSqlStatementContext(), queryContext.getHintValueContext(), queryContext.getParameters(), database, props, result));
    return result;
}

That is how ShardingSQLRouter creates the route context. In short:

  • Parse the sharding conditions (sharded tables, sharding columns, and their values) according to the sharding rules.
  • Validate the SQL statement.
  • Create the matching sharding route engine from the rules and build the route context.

Now let's see how the simplest sharding route engine, ShardingStandardRoutingEngine, is implemented:

// org.apache.shardingsphere.sharding.route.engine.type.standard.ShardingStandardRoutingEngine#route
public RouteContext route(final ShardingRule shardingRule) {
    RouteContext result = new RouteContext();
    // ☆☆☆☆ Core ☆☆☆☆ compute the data nodes from the rules
    Collection<DataNode> dataNodes = getDataNodes(shardingRule, shardingRule.getShardingTable(logicTableName));
    result.getOriginalDataNodes().addAll(originalDataNodes);
    for (DataNode each : dataNodes) {
        // Wrap as route units: logical-to-actual database mappings plus logical-to-actual table mappings
        result.getRouteUnits().add(
                new RouteUnit(new RouteMapper(each.getDataSourceName(), each.getDataSourceName()), Collections.singleton(new RouteMapper(logicTableName, each.getTableName()))));
    }
    return result;
}

// org.apache.shardingsphere.sharding.route.engine.type.standard.ShardingStandardRoutingEngine#getDataNodes
private Collection<DataNode> getDataNodes(final ShardingRule shardingRule, final ShardingTable shardingTable) {
    // Build the database sharding strategy from the rules
    ShardingStrategy databaseShardingStrategy = createShardingStrategy(shardingRule.getDatabaseShardingStrategyConfiguration(shardingTable),
            shardingRule.getShardingAlgorithms(), shardingRule.getDefaultShardingColumn());
    // Build the table sharding strategy from the rules
    ShardingStrategy tableShardingStrategy = createShardingStrategy(shardingRule.getTableShardingStrategyConfiguration(shardingTable),
            shardingRule.getShardingAlgorithms(), shardingRule.getDefaultShardingColumn());
    // Hint-based (forced) routing
    if (isRoutingByHint(shardingRule, shardingTable)) {
        return routeByHint(shardingTable, databaseShardingStrategy, tableShardingStrategy);
    }
    // Condition-based routing
    if (isRoutingByShardingConditions(shardingRule, shardingTable)) {
        return routeByShardingConditions(shardingRule, shardingTable, databaseShardingStrategy, tableShardingStrategy);
    }
    // Mixed routing: either the database or the table side is hint-based, the other condition-based
    return routeByMixedConditions(shardingRule, shardingTable, databaseShardingStrategy, tableShardingStrategy);
}

// org.apache.shardingsphere.sharding.route.engine.type.standard.ShardingStandardRoutingEngine#routeByShardingConditionsWithCondition
private Collection<DataNode> routeByShardingConditionsWithCondition(final ShardingRule shardingRule, final ShardingTable shardingTable,
                                                                        final ShardingStrategy databaseShardingStrategy, final ShardingStrategy tableShardingStrategy) {
    Collection<DataNode> result = new LinkedList<>();
    // Route by each sharding condition to collect the data nodes
    for (ShardingCondition each : shardingConditions.getConditions()) {
        Collection<DataNode> dataNodes = route0(shardingTable,
                databaseShardingStrategy, getShardingValuesFromShardingConditions(shardingRule, databaseShardingStrategy.getShardingColumns(), each),
                tableShardingStrategy, getShardingValuesFromShardingConditions(shardingRule, tableShardingStrategy.getShardingColumns(), each));
        result.addAll(dataNodes);
        originalDataNodes.add(dataNodes);
    }
    return result;
}

// org.apache.shardingsphere.sharding.route.engine.type.standard.ShardingStandardRoutingEngine#route0
private Collection<DataNode> route0(final ShardingTable shardingTable,
                                        final ShardingStrategy databaseShardingStrategy, final List<ShardingConditionValue> databaseShardingValues,
                                        final ShardingStrategy tableShardingStrategy, final List<ShardingConditionValue> tableShardingValues) {
    // Route to the target databases
    Collection<String> routedDataSources = routeDataSources(shardingTable, databaseShardingStrategy, databaseShardingValues);
    Collection<DataNode> result = new LinkedList<>();
    // Route to the target tables
    for (String each : routedDataSources) {
        result.addAll(routeTables(shardingTable, each, tableShardingStrategy, tableShardingValues));
    }
    return result;
}

// org.apache.shardingsphere.sharding.route.strategy.type.standard.StandardShardingStrategy#doSharding
public Collection<String> doSharding(final Collection<String> availableTargetNames, final Collection<ShardingConditionValue> shardingConditionValues,
                                     final DataNodeInfo dataNodeInfo, final ConfigurationProperties props) {
    ShardingConditionValue shardingConditionValue = shardingConditionValues.iterator().next();
    // Dispatch to the exact-match or range handler depending on the condition type of the sharding column
    Collection<String> shardingResult = shardingConditionValue instanceof ListShardingConditionValue
            ? doSharding(availableTargetNames, (ListShardingConditionValue) shardingConditionValue, dataNodeInfo)
            : doSharding(availableTargetNames, (RangeShardingConditionValue) shardingConditionValue, dataNodeInfo);
    Collection<String> result = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
    result.addAll(shardingResult);
    return result;
}

// org.apache.shardingsphere.sharding.route.strategy.type.standard.StandardShardingStrategy#doSharding
private Collection<String> doSharding(final Collection<String> availableTargetNames, final ListShardingConditionValue<?> shardingValue, final DataNodeInfo dataNodeInfo) {
    Collection<String> result = new LinkedList<>();
    for (Object each : shardingValue.getValues()) {
        // This is where the configured (possibly custom) sharding algorithm runs
        String target = shardingAlgorithm.doSharding(availableTargetNames,
                new PreciseShardingValue(shardingValue.getTableName(), shardingValue.getColumnName(), dataNodeInfo, each));
        if (null != target && availableTargetNames.contains(target)) {
            result.add(target);
        } else if (null != target && !availableTargetNames.contains(target)) {
            throw new ShardingRouteAlgorithmException(target, availableTargetNames);
        }
    }
    return result;
}

A quick analysis of the condition-based routing above:

  • The route context consists of a set of route units; each unit holds a database mapping and table mappings.
  • Routing covers database routing and table routing, each of which can be hint-based, condition-based, or mixed (hint on the database side with conditions on the table side, or the other way around).
  • The candidate data nodes (DataNodes) are derived from the actualDataNodes of the table's sharding rule.
  • Routing runs the sharding algorithm specified in the sharding rules, covering both exact-value and range conditions (a sketch of such an algorithm follows this list).
  • If the SQL carries no sharding conditions, it is routed to all databases and tables.
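As a concrete example, here is a hedged sketch of a custom standard sharding algorithm, a plain modulo over a numeric column. The StandardShardingAlgorithm SPI is what shardingAlgorithm.doSharding above ends up calling; the SPI registration details and the type name are simplified assumptions here:

// Sketch: a modulo-based StandardShardingAlgorithm (simplified; SPI registration omitted).
import java.util.Collection;

import org.apache.shardingsphere.sharding.api.sharding.standard.PreciseShardingValue;
import org.apache.shardingsphere.sharding.api.sharding.standard.RangeShardingValue;
import org.apache.shardingsphere.sharding.api.sharding.standard.StandardShardingAlgorithm;

public final class ModuloShardingAlgorithmSketch implements StandardShardingAlgorithm<Long> {
    
    @Override
    public String doSharding(final Collection<String> availableTargetNames, final PreciseShardingValue<Long> shardingValue) {
        // Exact-value condition: pick the target whose suffix equals value % target count.
        long suffix = shardingValue.getValue() % availableTargetNames.size();
        return availableTargetNames.stream().filter(each -> each.endsWith("_" + suffix)).findFirst().orElse(null);
    }
    
    @Override
    public Collection<String> doSharding(final Collection<String> availableTargetNames, final RangeShardingValue<Long> shardingValue) {
        // Range condition: a modulo algorithm cannot prune a range, so fan out to all targets.
        return availableTargetNames;
    }
    
    @Override
    public String getType() {
        // The SPI type name referenced in the sharding rule configuration (illustrative).
        return "DEMO_MOD";
    }
}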

Next, the read/write splitting implementation:

// org.apache.shardingsphere.readwritesplitting.route.ReadwriteSplittingSQLRouter#createRouteContext
@Override
public RouteContext createRouteContext(final QueryContext queryContext, final RuleMetaData globalRuleMetaData,
                                       final ShardingSphereDatabase database, final ReadwriteSplittingRule rule, final ConfigurationProperties props, final ConnectionContext connectionContext) {
    // ReadwriteSplittingSQLRouter is ordered last among all SQLRouters; if it is the one creating the route context, no other routing rule matched this SQL
    RouteContext result = new RouteContext();
    ReadwriteSplittingDataSourceRule singleDataSourceRule = rule.getSingleDataSourceRule();
    // Route directly with ReadwriteSplittingDataSourceRouter
    String dataSourceName = new ReadwriteSplittingDataSourceRouter(singleDataSourceRule, connectionContext).route(queryContext.getSqlStatementContext(), queryContext.getHintValueContext());
    result.getRouteUnits().add(new RouteUnit(new RouteMapper(singleDataSourceRule.getName(), dataSourceName), Collections.emptyList()));
    return result;
}

// org.apache.shardingsphere.readwritesplitting.route.ReadwriteSplittingSQLRouter#decorateRouteContext
@Override
public void decorateRouteContext(final RouteContext routeContext, final QueryContext queryContext, final ShardingSphereDatabase database,
                                 final ReadwriteSplittingRule rule, final ConfigurationProperties props, final ConnectionContext connectionContext) {
    Collection<RouteUnit> toBeRemoved = new LinkedList<>();
    Collection<RouteUnit> toBeAdded = new LinkedList<>();
    for (RouteUnit each : routeContext.getRouteUnits()) {
        // Look up the read/write splitting rule of the logical data source; proceed if the rule exists and the unit's actual data source equals the rule name
        String dataSourceName = each.getDataSourceMapper().getLogicName();
        Optional<ReadwriteSplittingDataSourceRule> dataSourceRule = rule.findDataSourceRule(dataSourceName);
        if (dataSourceRule.isPresent() && dataSourceRule.get().getName().equalsIgnoreCase(each.getDataSourceMapper().getActualName())) {
            toBeRemoved.add(each);
            // Rebuild the route unit
            String actualDataSourceName = new ReadwriteSplittingDataSourceRouter(dataSourceRule.get(), connectionContext).route(queryContext.getSqlStatementContext(),
                    queryContext.getHintValueContext());
            toBeAdded.add(new RouteUnit(new RouteMapper(each.getDataSourceMapper().getLogicName(), actualDataSourceName), each.getTableMappers()));
        }
    }
    routeContext.getRouteUnits().removeAll(toBeRemoved);
    routeContext.getRouteUnits().addAll(toBeAdded);
}

// org.apache.shardingsphere.readwritesplitting.route.ReadwriteSplittingDataSourceRouter#route
public String route(final SQLStatementContext sqlStatementContext, final HintValueContext hintValueContext) {
    for (QualifiedReadwriteSplittingDataSourceRouter each : getQualifiedRouters(connectionContext)) {
        // 1. write operations, locks, forced-write hints  2. transactions
        if (each.isQualified(sqlStatementContext, rule, hintValueContext)) {
            return each.route(rule);
        }
    }
    // Load-balance across the replicas (disabled replicas are filtered out)
    return new StandardReadwriteSplittingDataSourceRouter().route(rule);
}

// org.apache.shardingsphere.readwritesplitting.route.ReadwriteSplittingDataSourceRouter#getQualifiedRouters
private Collection<QualifiedReadwriteSplittingDataSourceRouter> getQualifiedRouters(final ConnectionContext connectionContext) {
    // 1. QualifiedReadwriteSplittingPrimaryDataSourceRouter: write operations, locks, and forced routing go to the primary
    // 2. QualifiedReadwriteSplittingTransactionalDataSourceRouter: transactions go to the primary by default; the FIXED and DYNAMIC strategies may read replicas (disabled replicas filtered out)
    return Arrays.asList(new QualifiedReadwriteSplittingPrimaryDataSourceRouter(), new QualifiedReadwriteSplittingTransactionalDataSourceRouter(connectionContext));
}

  • The read/write splitting router either creates route units or decorates them.
  • Since it is ordered last among all routers, being the one to create the route context means no other routing rule matched; in that case it routes directly with ReadwriteSplittingDataSourceRouter.
  • When decorating, it iterates over the route units; any unit matching a read/write splitting rule is re-routed with ReadwriteSplittingDataSourceRouter, and the original unit is replaced.
  • ReadwriteSplittingDataSourceRouter decides whether to hit the primary based on two conditions:
    • Write operations, locks, and forced routing go to the primary.
    • Transactions go to the primary by default; the FIXED and DYNAMIC strategies may read replicas (disabled replicas filtered out).
    • Everything else is load-balanced across the replicas, using the load balancing algorithm we configure (see the sketch after this list).
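The replica load balancer in the last step is pluggable. As a hedged illustration of the idea only (not the actual ShardingSphere SPI), round-robin selection is as simple as:

// Sketch: round-robin replica selection (illustrative only; not the real load-balancer SPI).
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

public final class RoundRobinReplicaPickerSketch {
    
    private final AtomicLong counter = new AtomicLong();
    
    public String pick(final List<String> enabledReplicas) {
        // Rotate through the enabled replicas; disabled ones are assumed filtered out upstream.
        int index = (int) (counter.getAndIncrement() % enabledReplicas.size());
        return enabledReplicas.get(index);
    }
}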

Rewrite Engine

// org.apache.shardingsphere.infra.rewrite.SQLRewriteEntry#rewrite
public SQLRewriteResult rewrite(final QueryContext queryContext, final RouteContext routeContext, final ConnectionContext connectionContext) {
    // ☆☆☆☆ Core ☆☆☆☆ create the SQL rewrite context, including sqlToken generation
    SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(queryContext, routeContext, connectionContext);
    SQLTranslatorRule rule = globalRuleMetaData.getSingleRule(SQLTranslatorRule.class);
    // Rewrite: pick the rewrite engine depending on whether route units exist
    return routeContext.getRouteUnits().isEmpty()
            ? new GenericSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, queryContext)
            : new RouteSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, routeContext, queryContext);
}

SQL rewriting consists of two parts:

  1. Creating the SQL rewrite context, including generating the sqlTokens that drive the rewrite.
  2. Rewriting the SQL based on the route context and the rewrite context.

Let's first look at how the rewrite context is created:

// org.apache.shardingsphere.infra.rewrite.SQLRewriteEntry#createSQLRewriteContext
private SQLRewriteContext createSQLRewriteContext(final QueryContext queryContext, final RouteContext routeContext, final ConnectionContext connectionContext) {
    HintValueContext hintValueContext = queryContext.getHintValueContext();
    SQLRewriteContext result = new SQLRewriteContext(database, queryContext.getSqlStatementContext(), queryContext.getSql(), queryContext.getParameters(), connectionContext, hintValueContext);
    // Decorate the rewrite context
    decorate(decorators, result, routeContext, hintValueContext);
    // Generate the sqlTokens; each token rewrites one segment of the SQL, such as a table name, a column, or a value
    result.generateSQLTokens();
    return result;
}

// org.apache.shardingsphere.infra.rewrite.SQLRewriteEntry#decorate
private void decorate(final Map<ShardingSphereRule, SQLRewriteContextDecorator> decorators, final SQLRewriteContext sqlRewriteContext,
                          final RouteContext routeContext, final HintValueContext hintValueContext) {
    // A hint can skip SQL rewriting entirely
    if (hintValueContext.isSkipSQLRewrite()) {
        return;
    }
    for (Entry<ShardingSphereRule, SQLRewriteContextDecorator> entry : decorators.entrySet()) {
        // ☆☆☆☆ Core ☆☆☆☆ the decorator is an extension point: register your own sqlTokens here to extend SQL rewriting
        entry.getValue().decorate(entry.getKey(), props, sqlRewriteContext, routeContext);
    }
}

// org.apache.shardingsphere.sharding.rewrite.context.ShardingSQLRewriteContextDecorator#decorate
public void decorate(final ShardingRule shardingRule, final ConfigurationProperties props, final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) {
    SQLStatementContext sqlStatementContext = sqlRewriteContext.getSqlStatementContext();
    if (sqlStatementContext instanceof InsertStatementContext && !containsShardingTable(shardingRule, sqlStatementContext)) {
        return;
    }
    if (!sqlRewriteContext.getParameters().isEmpty()) {
        Collection<ParameterRewriter> parameterRewriters =
                new ShardingParameterRewriterBuilder(shardingRule, routeContext, sqlRewriteContext.getDatabase().getSchemas(), sqlStatementContext).getParameterRewriters();
        rewriteParameters(sqlRewriteContext, parameterRewriters);
    }
    // Register the sqlToken generators
    sqlRewriteContext.addSQLTokenGenerators(new ShardingTokenGenerateBuilder(shardingRule, routeContext, sqlStatementContext).getSQLTokenGenerators());
}

// org.apache.shardingsphere.sharding.rewrite.token.ShardingTokenGenerateBuilder#getSQLTokenGenerators
public Collection<SQLTokenGenerator> getSQLTokenGenerators() {
    Collection<SQLTokenGenerator> result = new LinkedList<>();
    addSQLTokenGenerator(result, new TableTokenGenerator());
    addSQLTokenGenerator(result, new DistinctProjectionPrefixTokenGenerator());
    addSQLTokenGenerator(result, new ProjectionsTokenGenerator());
    addSQLTokenGenerator(result, new OrderByTokenGenerator());
    addSQLTokenGenerator(result, new AggregationDistinctTokenGenerator());
    addSQLTokenGenerator(result, new IndexTokenGenerator());
    addSQLTokenGenerator(result, new ConstraintTokenGenerator());
    addSQLTokenGenerator(result, new OffsetTokenGenerator());
    addSQLTokenGenerator(result, new RowCountTokenGenerator());
    addSQLTokenGenerator(result, new GeneratedKeyInsertColumnTokenGenerator());
    addSQLTokenGenerator(result, new GeneratedKeyForUseDefaultInsertColumnsTokenGenerator());
    addSQLTokenGenerator(result, new GeneratedKeyAssignmentTokenGenerator());
    addSQLTokenGenerator(result, new ShardingInsertValuesTokenGenerator());
    addSQLTokenGenerator(result, new GeneratedKeyInsertValuesTokenGenerator());
    addSQLTokenGenerator(result, new ShardingRemoveTokenGenerator());
    addSQLTokenGenerator(result, new CursorTokenGenerator());
    addSQLTokenGenerator(result, new FetchDirectionTokenGenerator());
    return result;
}

// org.apache.shardingsphere.sharding.rewrite.token.generator.impl.TableTokenGenerator#generateSQLTokens
private Collection<SQLToken> generateSQLTokens(final TableAvailable sqlStatementContext) {
    Collection<SQLToken> result = new LinkedList<>();
    for (SimpleTableSegment each : sqlStatementContext.getAllTables()) {
        TableNameSegment tableName = each.getTableName();
        if (shardingRule.findShardingTable(tableName.getIdentifier().getValue()).isPresent()) {
            result.add(new TableToken(tableName.getStartIndex(), tableName.getStopIndex(), tableName.getIdentifier(), (SQLStatementContext) sqlStatementContext, shardingRule));
        }
    }
    return result;
}

Creating the rewrite context is mostly about creating the sqlTokens, in a few steps:

  1. Load the rewrite-context decorators matching the configured rules, chiefly the sharding and data encryption decorators.

    decorators = OrderedSPILoader.getServices(SQLRewriteContextDecorator.class, database.getRuleMetaData().getRules());
    
  2. The decorators' main job is registering sqlToken generators.

  3. Generate the sqlTokens in one pass. Each sqlToken rewrites one segment of the SQL, such as a table name, a column, or a value. Take TableToken as an example: it carries the start and stop index of the table name taken from the parsed statement, and rewriting simply replaces the substring between those indexes, as sketched below.
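Here is a hedged, self-contained sketch of that splicing idea. The Token class is made up for illustration; the real builders appear in toSQL further down:

// Sketch: index-based token splicing, the core idea behind AbstractSQLBuilder#toSQL.
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

public final class TokenSpliceSketch {
    
    // A made-up token: replaces [startIndex, stopIndex] of the logical SQL with `text`.
    static final class Token {
        
        final int startIndex;
        final int stopIndex;
        final String text;
        
        Token(final int startIndex, final int stopIndex, final String text) {
            this.startIndex = startIndex;
            this.stopIndex = stopIndex;
            this.text = text;
        }
    }
    
    public static String rewrite(final String logicSQL, final List<Token> tokens) {
        // Sort tokens by start index, then stitch: untouched segment + token text, repeatedly.
        List<Token> sorted = new ArrayList<>(tokens);
        sorted.sort(Comparator.comparingInt(each -> each.startIndex));
        StringBuilder result = new StringBuilder();
        int current = 0;
        for (Token each : sorted) {
            result.append(logicSQL, current, each.startIndex).append(each.text);
            current = each.stopIndex + 1;
        }
        return result.append(logicSQL.substring(current)).toString();
    }
    
    public static void main(final String[] args) {
        // "t_order" occupies indexes 14..20 of the logical SQL; splice in the actual table name.
        String logicSQL = "SELECT * FROM t_order WHERE f_id = ?";
        System.out.println(rewrite(logicSQL, Collections.singletonList(new Token(14, 20, "t_order_0"))));
        // Prints: SELECT * FROM t_order_0 WHERE f_id = ?
    }
}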

Now for the rewrite itself:

// org.apache.shardingsphere.infra.rewrite.engine.RouteSQLRewriteEngine#rewrite
public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final QueryContext queryContext) {
    Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits = new LinkedHashMap<>(routeContext.getRouteUnits().size(), 1F);
    // Group the route units by data source
    for (Entry<String, Collection<RouteUnit>> entry : aggregateRouteUnitGroups(routeContext.getRouteUnits()).entrySet()) {
        Collection<RouteUnit> routeUnits = entry.getValue();
        // Queries against the same data source, containing no subquery, join, ORDER BY, pagination, or lock
        if (isNeedAggregateRewrite(sqlRewriteContext.getSqlStatementContext(), routeUnits)) {
            // Join the statements with UNION ALL
            sqlRewriteUnits.put(routeUnits.iterator().next(), createSQLRewriteUnit(sqlRewriteContext, routeContext, routeUnits));
        } else {
            // Rewrite the SQL and wrap it into a rewrite unit
            addSQLRewriteUnits(sqlRewriteUnits, sqlRewriteContext, routeContext, routeUnits);
        }
    }
    // Translate the SQL; the project ships no built-in implementation yet
    return new RouteSQLRewriteResult(translate(queryContext, sqlRewriteUnits));
}

// org.apache.shardingsphere.infra.rewrite.engine.RouteSQLRewriteEngine#createSQLRewriteUnit
private SQLRewriteUnit createSQLRewriteUnit(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final Collection<RouteUnit> routeUnits) {
    Collection<String> sql = new LinkedList<>();
    List<Object> params = new LinkedList<>();
    // Whether the parameters use $-prefixed markers
    boolean containsDollarMarker = sqlRewriteContext.getSqlStatementContext() instanceof SelectStatementContext
            && ((SelectStatementContext) (sqlRewriteContext.getSqlStatementContext())).isContainsDollarParameterMarker();
    for (RouteUnit each : routeUnits) {
        // Strip trailing semicolons from the SQL
        sql.add(SQLUtils.trimSemicolon(new RouteSQLBuilder(sqlRewriteContext, each).toSQL()));
        if (containsDollarMarker && !params.isEmpty()) {
            continue;
        }
        params.addAll(getParameters(sqlRewriteContext.getParameterBuilder(), routeContext, each));
    }
    // Join the statements with UNION ALL
    return new SQLRewriteUnit(String.join(" UNION ALL ", sql), params);
}

// org.apache.shardingsphere.infra.rewrite.engine.RouteSQLRewriteEngine#addSQLRewriteUnits
private void addSQLRewriteUnits(final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits, final SQLRewriteContext sqlRewriteContext,
                                final RouteContext routeContext, final Collection<RouteUnit> routeUnits) {
    for (RouteUnit each : routeUnits) {
        // ☆☆☆☆ Core ☆☆☆☆ RouteSQLBuilder.toSQL() rewrites the SQL
        sqlRewriteUnits.put(each, new SQLRewriteUnit(new RouteSQLBuilder(sqlRewriteContext, each).toSQL(), getParameters(sqlRewriteContext.getParameterBuilder(), routeContext, each)));
    }
}

// org.apache.shardingsphere.infra.rewrite.sql.impl.AbstractSQLBuilder#toSQL
public final String toSQL() {
    if (context.getSqlTokens().isEmpty()) {
        return context.getSql();
    }
    // Sort the sqlTokens by startIndex
    Collections.sort(context.getSqlTokens());
    // Iterate over the sqlTokens, replacing segments and stitching the SQL string together
    StringBuilder result = new StringBuilder();
    result.append(context.getSql(), 0, context.getSqlTokens().get(0).getStartIndex());
    for (SQLToken each : context.getSqlTokens()) {
        if (each instanceof ComposableSQLToken) {
            result.append(getComposableSQLTokenText((ComposableSQLToken) each));
        } else if (each instanceof SubstitutableColumnNameToken) {
            result.append(((SubstitutableColumnNameToken) each).toString(routeUnit));
        } else {
            // A sqlToken carries the data of its AST node, so its text can be emitted directly
            result.append(getSQLTokenText(each));
        }
        // Append the connecting text between this token and the next
        result.append(getConjunctionText(each));
    }
    return result.toString();
}

The rewrite logic itself is fairly simple: assemble the SQL from the sqlTokens. Since each sqlToken is generated with the data of its AST node, it only needs to be processed according to the rules, e.g. sharding or encryption. For example, a SELECT on t_order routed to t_order_0 and t_order_1 has its table token replaced per route unit, and if both units hit the same data source under the aggregation conditions above, the two rewritten statements are joined with UNION ALL.

Merge Engine

Before the merge engine, let's briefly look at how the query is actually executed:

// org.apache.shardingsphere.driver.jdbc.core.statement.ShardingSpherePreparedStatement#executeQuery0
private List<QueryResult> executeQuery0(final ExecutionContext executionContext) throws SQLException {
    // Group by the connection count of each data source; one connection can carry several SQLs of one data source
    ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(executionContext);
    return executor.getRegularExecutor().executeQuery(executionGroupContext, executionContext.getQueryContext(),
            new PreparedStatementExecuteQueryCallback(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType(),
                    metaDataContexts.getMetaData().getDatabase(databaseName).getResourceMetaData(), sqlStatement,
                    SQLExecutorExceptionHandler.isExceptionThrown()));
}

// org.apache.shardingsphere.infra.executor.sql.prepare.AbstractExecutionPrepareEngine#prepare
public final ExecutionGroupContext<T> prepare(final RouteContext routeContext, final Map<String, Integer> connectionOffsets, final Collection<ExecutionUnit> executionUnits,
                                                  final ExecutionGroupReportContext reportContext) throws SQLException {
    Collection<ExecutionGroup<T>> result = new LinkedList<>();
    // Group by data source
    for (Entry<String, List<ExecutionUnit>> entry : aggregateExecutionUnitGroups(executionUnits).entrySet()) {
        String dataSourceName = entry.getKey();
        // Split into execution-unit (SQL) lists according to the connection count of the data source
        List<List<ExecutionUnit>> executionUnitGroups = group(entry.getValue());
        ConnectionMode connectionMode = maxConnectionsSizePerQuery < entry.getValue().size() ? ConnectionMode.CONNECTION_STRICTLY : ConnectionMode.MEMORY_STRICTLY;
        // Group by the connection count of each data source; one connection can carry several SQLs of one data source
        result.addAll(group(dataSourceName, connectionOffsets.getOrDefault(dataSourceName, 0), executionUnitGroups, connectionMode));
    }
    // Extensible decorators
    return decorate(routeContext, result, reportContext);
}

// org.apache.shardingsphere.infra.executor.kernel.ExecutorEngine#execute
public <I, O> List<O> execute(final ExecutionGroupContext<I> executionGroupContext,
                                  final ExecutorCallback<I, O> firstCallback, final ExecutorCallback<I, O> callback, final boolean serial) throws SQLException {
    if (executionGroupContext.getInputGroups().isEmpty()) {
        return Collections.emptyList();
    }
    return serial ? serialExecute(executionGroupContext.getInputGroups().iterator(), executionGroupContext.getReportContext().getProcessId(), firstCallback, callback)
            : parallelExecute(executionGroupContext.getInputGroups().iterator(), executionGroupContext.getReportContext().getProcessId(), firstCallback, callback);
}

Execution consists of two parts (see the sketch after this list):

  1. Group the execution units by data source and the per-query connection limit (maxConnectionsSizePerQuery).
  2. Execute the groups concurrently for maximum throughput (serially inside a distributed transaction).
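To make the grouping rule concrete, here is a hedged sketch of the arithmetic: N execution units for one data source are split into at most maxConnectionsSizePerQuery groups, one connection per group, and the connection mode flips to CONNECTION_STRICTLY once the units outnumber the available connections. The numbers are illustrative only:

// Sketch: how execution units map onto connections (illustrative arithmetic only).
import java.util.ArrayList;
import java.util.List;

public final class ExecutionGroupingSketch {
    
    public static <T> List<List<T>> group(final List<T> executionUnits, final int maxConnectionsSizePerQuery) {
        // Each group runs on one connection; there are at most maxConnectionsSizePerQuery groups.
        int groupCount = Math.min(executionUnits.size(), maxConnectionsSizePerQuery);
        int groupSize = (int) Math.ceil((double) executionUnits.size() / groupCount);
        List<List<T>> result = new ArrayList<>();
        for (int i = 0; i < executionUnits.size(); i += groupSize) {
            result.add(executionUnits.subList(i, Math.min(i + groupSize, executionUnits.size())));
        }
        return result;
    }
    
    public static void main(final String[] args) {
        // 6 rewritten SQLs for one data source with maxConnectionsSizePerQuery = 2:
        // -> 2 groups of 3, CONNECTION_STRICTLY (more units than connections, results buffered in memory).
        System.out.println(group(List.of("sql1", "sql2", "sql3", "sql4", "sql5", "sql6"), 2));
    }
}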

Now for the merge engine itself:

// org.apache.shardingsphere.driver.jdbc.core.statement.ShardingSpherePreparedStatement#mergeQuery
private MergedResult mergeQuery(final List<QueryResult> queryResults, final SQLStatementContext sqlStatementContext) throws SQLException {
    // The merge engine covers shard merging, data masking, and data decryption
    MergeEngine mergeEngine = new MergeEngine(metaDataContexts.getMetaData().getDatabase(databaseName),
            metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
    return mergeEngine.merge(queryResults, sqlStatementContext);
}

// org.apache.shardingsphere.infra.merge.MergeEngine#merge
public MergedResult merge(final List<QueryResult> queryResults, final SQLStatementContext sqlStatementContext) throws SQLException {
    // Merge the sharded results
    Optional<MergedResult> mergedResult = executeMerge(queryResults, sqlStatementContext);
    // Post-process the result values: data masking, decryption
    Optional<MergedResult> result = mergedResult.isPresent() ? Optional.of(decorate(mergedResult.get(), sqlStatementContext)) : decorate(queryResults.get(0), sqlStatementContext);
    return result.orElseGet(() -> new TransparentMergedResult(queryResults.get(0)));
}

// org.apache.shardingsphere.infra.merge.MergeEngine#executeMerge
private Optional<MergedResult> executeMerge(final List<QueryResult> queryResults, final SQLStatementContext sqlStatementContext) throws SQLException {
    for (Entry<ShardingSphereRule, ResultProcessEngine> entry : engines.entrySet()) {
        if (entry.getValue() instanceof ResultMergerEngine) {
            ResultMerger resultMerger = ((ResultMergerEngine) entry.getValue()).newInstance(database.getName(), database.getProtocolType(), entry.getKey(), props, sqlStatementContext);
            return Optional.of(resultMerger.merge(queryResults, sqlStatementContext, database, connectionContext));
        }
    }
    return Optional.empty();
}

// org.apache.shardingsphere.sharding.merge.dql.ShardingDQLResultMerger#build
private MergedResult build(final List<QueryResult> queryResults, final SelectStatementContext selectStatementContext,
                               final Map<String, Integer> columnLabelIndexMap, final ShardingSphereDatabase database) throws SQLException {
    String defaultSchemaName = new DatabaseTypeRegistry(selectStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName());
    ShardingSphereSchema schema = selectStatementContext.getTablesContext().getSchemaName()
            .map(database::getSchema).orElseGet(() -> database.getSchema(defaultSchemaName));
    // GROUP BY, or aggregate functions such as sum() or min() in the projection
    if (isNeedProcessGroupBy(selectStatementContext)) {
        return getGroupByMergedResult(queryResults, selectStatementContext, columnLabelIndexMap, schema);
    }
    // DISTINCT deduplication
    if (isNeedProcessDistinctRow(selectStatementContext)) {
        setGroupByForDistinctRow(selectStatementContext);
        return getGroupByMergedResult(queryResults, selectStatementContext, columnLabelIndexMap, schema);
    }
    // ORDER BY
    if (isNeedProcessOrderBy(selectStatementContext)) {
        return new OrderByStreamMergedResult(queryResults, selectStatementContext, schema);
    }
    return new IteratorStreamMergedResult(queryResults);
}

// org.apache.shardingsphere.infra.merge.MergeEngine#decorate
private MergedResult decorate(final MergedResult mergedResult, final SQLStatementContext sqlStatementContext) throws SQLException {
    MergedResult result = null;
    for (Entry<ShardingSphereRule, ResultProcessEngine> entry : engines.entrySet()) {
        if (entry.getValue() instanceof ResultDecoratorEngine) {
            // Obtain the result-set decorator
            ResultDecorator resultDecorator = getResultDecorator(sqlStatementContext, entry);
            result = null == result ? resultDecorator.decorate(mergedResult, sqlStatementContext, entry.getKey()) : resultDecorator.decorate(result, sqlStatementContext, entry.getKey());
        }
    }
    return null == result ? mergedResult : result;
}

The merge engine has two parts:

  1. Merging the query results: aggregation, ordering, deduplication, pagination.
  2. Post-processing the merged values: data masking, decryption.

Note that the merge engine merely returns a wrapping MergedResult; the actual merging happens lazily while iterating over the result data. Let's use ordered (ORDER BY) result merging as an example to see how the merge really runs:

// org.apache.shardingsphere.sharding.merge.dql.orderby.OrderByStreamMergedResult
public class OrderByStreamMergedResult extends StreamMergedResult {
    
    private final Collection<OrderByItem> orderByItems;
    
    @Getter(AccessLevel.PROTECTED)
    private final Queue<OrderByValue> orderByValuesQueue;
    
    @Getter(AccessLevel.PROTECTED)
    private boolean isFirstNext;
    
    public OrderByStreamMergedResult(final List<QueryResult> queryResults, final SelectStatementContext selectStatementContext, final ShardingSphereSchema schema) throws SQLException {
        orderByItems = selectStatementContext.getOrderByContext().getItems();
        orderByValuesQueue = new PriorityQueue<>(queryResults.size());
        orderResultSetsToQueue(queryResults, selectStatementContext, schema);
        isFirstNext = true;
    }
    
    private void orderResultSetsToQueue(final List<QueryResult> queryResults, final SelectStatementContext selectStatementContext, final ShardingSphereSchema schema) throws SQLException {
        for (QueryResult each : queryResults) {
            OrderByValue orderByValue = new OrderByValue(each, orderByItems, selectStatementContext, schema);
            if (orderByValue.next()) {
                orderByValuesQueue.offer(orderByValue);
            }
        }
        setCurrentQueryResult(orderByValuesQueue.isEmpty() ? queryResults.get(0) : orderByValuesQueue.peek().getQueryResult());
    }
    
    @Override
    public boolean next() throws SQLException {
        if (orderByValuesQueue.isEmpty()) {
            return false;
        }
        if (isFirstNext) {
            isFirstNext = false;
            return true;
        }
        // Poll the head result set from the priority queue
        OrderByValue firstOrderByValue = orderByValuesQueue.poll();
        // If that result set has another row, copy the row's sort values into the OrderByValue
        // and re-offer it to the queue, keeping all result sets in the correct order
        if (firstOrderByValue.next()) {
            orderByValuesQueue.offer(firstOrderByValue);
        }
        // An empty queue means the last row has been consumed
        if (orderByValuesQueue.isEmpty()) {
            return false;
        }
        // Peek at the result set whose current row sorts first, without removing it
        setCurrentQueryResult(orderByValuesQueue.peek().getQueryResult());
        return true;
    }
}

// org.apache.shardingsphere.sharding.merge.dql.orderby.OrderByValue
public final class OrderByValue implements Comparable<OrderByValue> {
  
    public boolean next() throws SQLException {
        // Advance the cursor
        boolean result = queryResult.next();
        // Copy the sort-column values at the current cursor into orderValues
        orderValues = result ? getOrderValues() : Collections.emptyList();
        return result;
    }
    
    private List<Comparable<?>> getOrderValues() throws SQLException {
        List<Comparable<?>> result = new ArrayList<>(orderByItems.size());
        for (OrderByItem each : orderByItems) {
            Object value = queryResult.getValue(each.getIndex(), Object.class);
            ShardingSpherePreconditions.checkState(null == value || value instanceof Comparable, () -> new NotImplementComparableValueException("Order by", value));
            result.add((Comparable<?>) value);
        }
        return result;
    }
    
    @Override
    public int compareTo(final OrderByValue orderByValue) {
        int i = 0;
        // Compare the current rows of the cursors, column by column
        for (OrderByItem each : orderByItems) {
            int result = CompareUtils.compareTo(orderValues.get(i), orderByValue.orderValues.get(i), each.getSegment().getOrderDirection(),
                    each.getSegment().getNullsOrderType(selectStatementContext.getDatabaseType()), orderValuesCaseSensitive.get(i));
            if (0 != result) {
                return result;
            }
            i++;
        }
        return 0;
    }
}

A quick analysis of the ordered-merge code, assuming the query

 select f_id, f_name from t_user order by f_id;

is routed to four tables after sharding, producing four result sets. Here is how they are merged:

  1. Each result set is wrapped into an OrderByValue, and the sort-column values of its first row are copied into the OrderByValue; OrderByValues compare to each other by those copied sort values.
  2. The wrapped OrderByValues are put into a priority queue, so initially the queue is ordered by the first row of each result set.
  3. Each call to OrderByStreamMergedResult.next polls the head result set from the queue, advances its cursor by one row, copies the new row's sort values into its OrderByValue, and offers it back into the queue.
  4. It then peeks at (without removing) the new head of the queue; that result set's cursor now points at the smallest remaining row.
  5. Steps 3 and 4 repeat until every row of every result set has been consumed, which guarantees a globally ordered traversal; the sketch below distills the pattern.
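This is a classic k-way merge over sorted streams. Here is a hedged, self-contained sketch with iterators standing in for JDBC cursors (all names are illustrative):

// Sketch: k-way merge with a priority queue, the same idea as OrderByStreamMergedResult.
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

public final class KWayMergeSketch {
    
    // Stands in for one shard's sorted QueryResult: an iterator plus its current row.
    static final class Cursor {
        
        final Iterator<Integer> rows;
        Integer current;
        
        Cursor(final Iterator<Integer> rows) {
            this.rows = rows;
        }
        
        boolean advance() {
            current = rows.hasNext() ? rows.next() : null;
            return null != current;
        }
    }
    
    public static void main(final String[] args) {
        List<List<Integer>> shardResults = List.of(List.of(1, 4, 7), List.of(2, 5, 8), List.of(3, 6, 9));
        PriorityQueue<Cursor> queue = new PriorityQueue<>(Comparator.comparing(each -> each.current));
        for (List<Integer> each : shardResults) {
            Cursor cursor = new Cursor(each.iterator());
            // Only non-empty result sets enter the queue (mirrors orderResultSetsToQueue).
            if (cursor.advance()) {
                queue.offer(cursor);
            }
        }
        while (!queue.isEmpty()) {
            Cursor smallest = queue.poll();
            System.out.print(smallest.current + " ");
            // Re-offer the cursor with its next row, exactly like next() above.
            if (smallest.advance()) {
                queue.offer(smallest);
            }
        }
        // Prints: 1 2 3 4 5 6 7 8 9
    }
}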

That is the concrete logic of ordered result merging. If you are interested in the other merge strategies, have a look at the remaining implementations of MergedResult.

Summary

Finally, look back at the architecture diagram from the beginning: the code design maps onto it exactly. The main flow is SQL parse -> route -> rewrite -> execute -> merge.

Note: for reasons of length, the full code is not included, and the sharding algorithms, read/write splitting load balancing, encryption/masking, and transactions are not covered in depth; each may get a dedicated article later.

References:

shardingsphere.apache.org/document/cu…

juejin.cn/post/734316…