ShardingSphere Source Code Analysis 5: Rewrite Engine, Part 2


Goal

Picking up where the previous post left off: when ShardingSphere builds the rewritten SQL, one object is used again and again, the sqlToken (SQL token).

As shown here:

public final String toSQL() {
    if (context.getSqlTokens().isEmpty()) {
        return context.getSql();
    }
    // Sort the SQL tokens
    Collections.sort(context.getSqlTokens());
    StringBuilder result = new StringBuilder();
    result.append(context.getSql(), 0, context.getSqlTokens().get(0).getStartIndex());
    for (SQLToken each : context.getSqlTokens()) {
        result.append(each instanceof ComposableSQLToken ? getComposableSQLTokenText((ComposableSQLToken) each) : getSQLTokenText(each));
        result.append(getConjunctionText(each));
    }
    return result.toString();
}

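The SQL tokens are read from the context, so how are they generated in the first place? With that question in mind, let's trace back to the code that creates the SQL rewrite context.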

Source Code Analysis

1. Creating the context

private SQLRewriteContext createSQLRewriteContext(final String sql, final List<Object> parameters, final SQLStatementContext<?> sqlStatementContext, final RouteContext routeContext) {
    SQLRewriteContext result = new SQLRewriteContext(schema, sqlStatementContext, sql, parameters);
    // Decorator pattern: registers the SQL token generators
    decorate(decorators, result, routeContext);
    // Generate the SQL tokens and store them in the context
    result.generateSQLTokens();
    return result;
}

1.1. Loading the SQL token generators

private void decorate(final Map<ShardingSphereRule, SQLRewriteContextDecorator> decorators, final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) {
    decorators.forEach((key, value) -> value.decorate(key, props, sqlRewriteContext, routeContext));
}

The decorators are loaded at the rewrite entry point, which leaves an SPI extension hook:

public SQLRewriteEntry(final ShardingSphereSchema schema, final ConfigurationProperties props, final Collection<ShardingSphereRule> rules) {
    this.schema = schema;
    this.props = props;
    decorators = OrderedSPIRegistry.getRegisteredServices(SQLRewriteContextDecorator.class, rules);
}

Three decorators are supported: ShardingSQLRewriteContextDecorator, EncryptSQLRewriteContextDecorator, and ShadowSQLRewriteContextDecorator.

Following the decorate code further (the sharding implementation, judging by its ShardingRule parameter):

public void decorate(final ShardingRule shardingRule, final ConfigurationProperties props, final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) {
    if (routeContext.isFederated()) {
        return;
    }
    for (ParameterRewriter each : new ShardingParameterRewriterBuilder(shardingRule, routeContext).getParameterRewriters(sqlRewriteContext.getSchema())) {
        if (!sqlRewriteContext.getParameters().isEmpty() && each.isNeedRewrite(sqlRewriteContext.getSqlStatementContext())) {
            // Rewrite the parameters
            each.rewrite(sqlRewriteContext.getParameterBuilder(), sqlRewriteContext.getSqlStatementContext(), sqlRewriteContext.getParameters());
        }
    }
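    /**
     * 1. Build a sharding token generator builder from the sharding rule and the route context, then use it to obtain the matching SQL token generators.
     * 2. Add those SQL token generators to the SQL rewrite context.
     */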
    sqlRewriteContext.addSQLTokenGenerators(new ShardingTokenGenerateBuilder(shardingRule, routeContext).getSQLTokenGenerators());
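}

The getSQLTokenGenerators method called on ShardingTokenGenerateBuilder above:

public Collection<SQLTokenGenerator> getSQLTokenGenerators() {
    // Build the SQL token generators and collect them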
    Collection<SQLTokenGenerator> result = buildSQLTokenGenerators();
    for (SQLTokenGenerator each : result) {
        // Inject the sharding rule
        if (each instanceof ShardingRuleAware) {
            ((ShardingRuleAware) each).setShardingRule(shardingRule);
        }
        // Inject the route context
        if (each instanceof RouteContextAware) {
            ((RouteContextAware) each).setRouteContext(routeContext);
        }
    }
    return result;
}
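
The injection loop above relies on "aware" interfaces: a generator states what it needs by implementing an interface, and the builder hands the dependency over after an instanceof check. Here is a minimal, self-contained sketch of that idea. All names in it (AwareInjectionSketch, RuleAware, RouteAware, and the String stand-ins for the rule and the route) are hypothetical simplifications, not ShardingSphere types.

```java
import java.util.List;

// Hypothetical, simplified stand-ins; these are not the ShardingSphere classes.
final class AwareInjectionSketch {

    interface TokenGenerator {
    }

    // A generator implements this to ask for the (stand-in) sharding rule.
    interface RuleAware {
        void setRule(String rule);
    }

    // A generator implements this to ask for the (stand-in) route context.
    interface RouteAware {
        void setRoute(String route);
    }

    static final class TableGenerator implements TokenGenerator, RuleAware {
        String rule;

        @Override
        public void setRule(final String rule) {
            this.rule = rule;
        }
    }

    static final class OffsetGenerator implements TokenGenerator, RouteAware {
        String route;

        @Override
        public void setRoute(final String route) {
            this.route = route;
        }
    }

    // Each generator only receives the dependencies it declared via its interfaces.
    static List<TokenGenerator> inject(final List<TokenGenerator> generators, final String rule, final String route) {
        for (TokenGenerator each : generators) {
            if (each instanceof RuleAware) {
                ((RuleAware) each).setRule(rule);
            }
            if (each instanceof RouteAware) {
                ((RouteAware) each).setRoute(route);
            }
        }
        return generators;
    }
}
```

With this shape, adding a new generator does not require touching the injection loop: the new class simply implements whichever aware interfaces it needs.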

As the next method shows, there are a lot of token generators...

private Collection<SQLTokenGenerator> buildSQLTokenGenerators() {
    // The various token generators
    Collection<SQLTokenGenerator> result = new LinkedList<>();
    // Filtering happens while adding: generators that do not apply to a single-route query are skipped
    addSQLTokenGenerator(result, new TableTokenGenerator());
    addSQLTokenGenerator(result, new DistinctProjectionPrefixTokenGenerator());
    addSQLTokenGenerator(result, new ProjectionsTokenGenerator());
    addSQLTokenGenerator(result, new OrderByTokenGenerator());
    addSQLTokenGenerator(result, new AggregationDistinctTokenGenerator());
    addSQLTokenGenerator(result, new IndexTokenGenerator());
    addSQLTokenGenerator(result, new ConstraintTokenGenerator());
    addSQLTokenGenerator(result, new OffsetTokenGenerator());
    addSQLTokenGenerator(result, new RowCountTokenGenerator());
    addSQLTokenGenerator(result, new GeneratedKeyInsertColumnTokenGenerator());
    addSQLTokenGenerator(result, new GeneratedKeyForUseDefaultInsertColumnsTokenGenerator());
    addSQLTokenGenerator(result, new GeneratedKeyAssignmentTokenGenerator());
    addSQLTokenGenerator(result, new ShardingInsertValuesTokenGenerator());
    addSQLTokenGenerator(result, new GeneratedKeyInsertValuesTokenGenerator());
    return result;
}

That is far too many to grind through in one sitting... this time we only analyze tableToken.

1.2. Generating the SQL tokens

Next comes the stage where the generators are used to create the SQL tokens:

// Generate the SQL tokens and store them in the context
result.generateSQLTokens();

public List<SQLToken> generateSQLTokens(final SQLStatementContext sqlStatementContext, final List<Object> parameters, final ShardingSphereSchema schema) {
    List<SQLToken> result = new LinkedList<>();
    for (SQLTokenGenerator each : sqlTokenGenerators) {
        // Pass the parameters, the schema, and the tokens generated so far to the generator
        setUpSQLTokenGenerator(each, parameters, schema, result);
        if (!each.isGenerateSQLToken(sqlStatementContext)) {
            continue;
        }
        // Generate the SQL token(s)
        if (each instanceof OptionalSQLTokenGenerator) {
            SQLToken sqlToken = ((OptionalSQLTokenGenerator) each).generateSQLToken(sqlStatementContext);
            if (!result.contains(sqlToken)) {
                result.add(sqlToken);
            }
        } else if (each instanceof CollectionSQLTokenGenerator) {
            result.addAll(((CollectionSQLTokenGenerator) each).generateSQLTokens(sqlStatementContext));
        }
    }
    return result;
}
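
The loop above distinguishes two generator shapes: one that yields at most a single token (added only if not already present) and one that yields a collection (added wholesale). A minimal sketch of that dispatch follows. The names TokenDispatchSketch, Generator, single, and many are hypothetical, and tokens are plain strings here rather than SQLToken objects.

```java
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;

// Hypothetical, simplified stand-ins; these are not the ShardingSphere classes.
final class TokenDispatchSketch {

    interface Generator {
        boolean isGenerate();
    }

    interface OptionalGenerator extends Generator {
        String generate();
    }

    interface CollectionGenerator extends Generator {
        Collection<String> generateAll();
    }

    // Factory for a generator that yields one token.
    static Generator single(final String token, final boolean enabled) {
        return new OptionalGenerator() {
            @Override
            public boolean isGenerate() {
                return enabled;
            }

            @Override
            public String generate() {
                return token;
            }
        };
    }

    // Factory for a generator that yields several tokens.
    static Generator many(final boolean enabled, final String... tokens) {
        return new CollectionGenerator() {
            @Override
            public boolean isGenerate() {
                return enabled;
            }

            @Override
            public Collection<String> generateAll() {
                return Arrays.asList(tokens);
            }
        };
    }

    static List<String> generateTokens(final List<Generator> generators) {
        List<String> result = new LinkedList<>();
        for (Generator each : generators) {
            if (!each.isGenerate()) {
                continue; // this generator does not apply to the statement
            }
            if (each instanceof OptionalGenerator) {
                String token = ((OptionalGenerator) each).generate();
                if (!result.contains(token)) {
                    result.add(token); // single tokens are de-duplicated
                }
            } else if (each instanceof CollectionGenerator) {
                result.addAll(((CollectionGenerator) each).generateAll());
            }
        }
        return result;
    }
}
```

Note that, as in the original loop, only the single-token branch checks for duplicates; tokens coming from a collection generator are appended as they are.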

In the generateSQLTokens step, each generator has its own logic. Here is the one for tableToken:

public Collection<TableToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
    return sqlStatementContext instanceof TableAvailable ? generateSQLTokens((TableAvailable) sqlStatementContext) : Collections.emptyList();
}

private Collection<TableToken> generateSQLTokens(final TableAvailable sqlStatementContext) {
    Collection<TableToken> result = new LinkedList<>();
    for (SimpleTableSegment each : sqlStatementContext.getAllTables()) {
        if (shardingRule.findTableRule(each.getTableName().getIdentifier().getValue()).isPresent()) {
            // Mark the table name with a SQL token
            result.add(new TableToken(each.getTableName().getStartIndex(), each.getTableName().getStopIndex(), each, (SQLStatementContext) sqlStatementContext, shardingRule));
        }
    }
    return result;
}

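This step creates a token for every table name that has a sharding rule and records its start and stop positions in the SQL. The rewritten (actual) table name itself is resolved later, when the token is turned into text.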

2. Using the sqlToken

Back to the place where the SQL tokens are used, shown at the start of this post:

public final String toSQL() {
    if (context.getSqlTokens().isEmpty()) {
        return context.getSql();
    }
    // Sort the SQL tokens
    Collections.sort(context.getSqlTokens());
    StringBuilder result = new StringBuilder();
    // Append the part of the SQL that precedes the first token
    result.append(context.getSql(), 0, context.getSqlTokens().get(0).getStartIndex());
    for (SQLToken each : context.getSqlTokens()) {
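        // Append the rewritten token text (for a tableToken, the rewritten table name)
        result.append(each instanceof ComposableSQLToken ? getComposableSQLTokenText((ComposableSQLToken) each) : getSQLTokenText(each));
        // Append the original SQL that follows this token, up to the next token or the end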
        result.append(getConjunctionText(each));
    }
    return result.toString();
}

private String getComposableSQLTokenText(final ComposableSQLToken composableSQLToken) {
    StringBuilder result = new StringBuilder();
    for (SQLToken each : composableSQLToken.getSqlTokens()) {
        result.append(getSQLTokenText(each));
        result.append(getConjunctionText(each));
    }
    return result.toString();
}

protected String getSQLTokenText(final SQLToken sqlToken) {
    if (sqlToken instanceof RouteUnitAware) {
        return ((RouteUnitAware) sqlToken).toString(routeUnit);
    }
    return sqlToken.toString();
}

Looking only at the toString() implementation of tableToken:

public String toString(final RouteUnit routeUnit) {
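    // First try to get the actual table name from the route information
    String actualTableName = getLogicAndActualTables(routeUnit).get(tableName.getValue().toLowerCase());
    // If there is none, the table has no sharding rule, so fall back to the lowercase logical name
    actualTableName = null == actualTableName ? tableName.getValue().toLowerCase() : actualTableName;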
    return tableName.getQuoteCharacter().wrap(actualTableName);
}

private Map<String, String> getLogicAndActualTables(final RouteUnit routeUnit) {
    Collection<String> tableNames = sqlStatementContext.getTablesContext().getTableNames();
    Map<String, String> result = new HashMap<>(tableNames.size(), 1);
    for (RouteMapper each : routeUnit.getTableMappers()) {
        // Take the logical and actual table names from the route information
        result.put(each.getLogicName().toLowerCase(), each.getActualName());
        // If there are binding tables, add them as well
        result.putAll(shardingRule.getLogicAndActualTablesFromBindingTable(routeUnit.getDataSourceMapper().getLogicName(), each.getLogicName(), each.getActualName(), tableNames));
    }
    return result;
}
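
The lookup in toString(routeUnit) boils down to "map the lowercase logical name to the actual name, and fall back to the logical name when there is no mapping". A small sketch of that rule, with hypothetical names (ActualTableLookupSketch, resolve) and a plain String standing in for the quote character:

```java
import java.util.Map;

// Hypothetical helper; a simplified stand-in for the lookup done in TableToken.
final class ActualTableLookupSketch {

    static String resolve(final Map<String, String> logicToActual, final String logicName, final String quote) {
        // Keys are stored in lowercase, so normalize before looking up.
        String key = logicName.toLowerCase();
        String actual = logicToActual.get(key);
        // No mapping means no sharding rule applies; keep the (lowercase) logical name.
        if (null == actual) {
            actual = key;
        }
        // Re-apply the quoting used in the original SQL (empty string for none).
        return quote + actual + quote;
    }
}
```

For example, with a map containing only t_order -> t_order_1, resolving T_ORDER with a backtick quote gives the quoted actual name, while resolving an unmapped table such as t_user with no quote returns t_user unchanged.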

That gives us the table name. Now back to the getConjunctionText method seen earlier:

private String getConjunctionText(final SQLToken sqlToken) {
    return context.getSql().substring(getStartIndex(sqlToken), getStopIndex(sqlToken));
}

private int getStartIndex(final SQLToken sqlToken) {
    int startIndex = sqlToken instanceof Substitutable ? ((Substitutable) sqlToken).getStopIndex() + 1 : sqlToken.getStartIndex();
    return Math.min(startIndex, context.getSql().length());
}

private int getStopIndex(final SQLToken sqlToken) {
    int currentSQLTokenIndex = context.getSqlTokens().indexOf(sqlToken);
    return context.getSqlTokens().size() - 1 == currentSQLTokenIndex ? context.getSql().length() : context.getSqlTokens().get(currentSQLTokenIndex + 1).getStartIndex();
}
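
Putting toSQL and getConjunctionText together, the rewrite is a splice: copy the SQL before the first token, then for each token emit its replacement text followed by the untouched SQL up to the next token (or the end). Here is a self-contained sketch of that splice. TokenSpliceSketch and Token are hypothetical names, every token is treated as substitutable (it replaces the characters from startIndex to stopIndex inclusive), and the replacement text is fixed up front instead of being resolved per route unit.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical, simplified model of the token splice; not the ShardingSphere classes.
final class TokenSpliceSketch {

    static final class Token {
        final int startIndex; // first replaced character
        final int stopIndex;  // last replaced character (inclusive)
        final String text;    // replacement text

        Token(final int startIndex, final int stopIndex, final String text) {
            this.startIndex = startIndex;
            this.stopIndex = stopIndex;
            this.text = text;
        }
    }

    static String toSQL(final String sql, final List<Token> tokens) {
        if (tokens.isEmpty()) {
            return sql;
        }
        // Sort by position so the splice walks the SQL from left to right.
        List<Token> sorted = new ArrayList<>(tokens);
        sorted.sort(Comparator.comparingInt((Token t) -> t.startIndex));
        StringBuilder result = new StringBuilder();
        // Part of the SQL before the first token.
        result.append(sql, 0, sorted.get(0).startIndex);
        for (int i = 0; i < sorted.size(); i++) {
            Token each = sorted.get(i);
            result.append(each.text);
            // "Conjunction" text: original SQL between this token and the next one (or the end).
            int from = Math.min(each.stopIndex + 1, sql.length());
            int to = i == sorted.size() - 1 ? sql.length() : sorted.get(i + 1).startIndex;
            result.append(sql, from, to);
        }
        return result.toString();
    }
}
```

Given the SQL SELECT * FROM t_order o JOIN t_item i ON o.id = i.oid and two tokens that cover t_order and t_item with the replacement texts t_order_0 and t_item_0, this sketch returns SELECT * FROM t_order_0 o JOIN t_item_0 i ON o.id = i.oid.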