Goal
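Picking up where the previous article left off: when ShardingSphere builds the rewritten SQL, the object it works with most is the sqlToken, i.e. the SQL token.
The code is as follows: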
public final String toSQL() {
if (context.getSqlTokens().isEmpty()) {
return context.getSql();
}
// Sort the sqlTokens by their position in the SQL
Collections.sort(context.getSqlTokens());
StringBuilder result = new StringBuilder();
result.append(context.getSql(), 0, context.getSqlTokens().get(0).getStartIndex());
for (SQLToken each : context.getSqlTokens()) {
result.append(each instanceof ComposableSQLToken ? getComposableSQLTokenText((ComposableSQLToken) each) : getSQLTokenText(each));
result.append(getConjunctionText(each));
}
return result.toString();
}
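Collections.sort can only order the list because SQLToken is comparable. Sorting is needed because the generators run one after another and append tokens in generator order, not in SQL order. A minimal sketch, assuming (as in ShardingSphere) that tokens compare by start index; TokenOrderDemo and its nested Token are hypothetical stand-ins, not the real classes:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class TokenOrderDemo {

    // Hypothetical stand-in for SQLToken: ordering uses only the start index
    static final class Token implements Comparable<Token> {
        final int startIndex;
        final String label;

        Token(int startIndex, String label) {
            this.startIndex = startIndex;
            this.label = label;
        }

        @Override
        public int compareTo(Token other) {
            // A token that starts earlier in the SQL sorts first
            return Integer.compare(startIndex, other.startIndex);
        }
    }

    // Sorts a copy of the tokens, as toSQL() sorts the context's list, and returns their labels
    static List<String> sortedLabels(List<Token> tokens) {
        List<Token> copy = new ArrayList<>(tokens);
        Collections.sort(copy);
        List<String> labels = new ArrayList<>();
        for (Token each : copy) {
            labels.add(each.label);
        }
        return labels;
    }

    public static void main(String[] args) {
        // Tokens arrive in generator order, not in SQL order
        List<Token> tokens = new ArrayList<>();
        tokens.add(new Token(30, "orderBy"));
        tokens.add(new Token(14, "table"));
        tokens.add(new Token(7, "projections"));
        System.out.println(sortedLabels(tokens)); // [projections, table, orderBy]
    }
}
```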
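As you can see, the sqlTokens are read from the context. So how are they generated? With that question in mind, let's trace back to the code that creates the SQL rewrite context.
Source code analysis
1. Creating the context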
private SQLRewriteContext createSQLRewriteContext(final String sql, final List<Object> parameters, final SQLStatementContext<?> sqlStatementContext, final RouteContext routeContext) {
SQLRewriteContext result = new SQLRewriteContext(schema, sqlStatementContext, sql, parameters);
// Decorator pattern: create the SQL token generators
decorate(decorators, result, routeContext);
// Generate the SQL tokens and store them in the context
result.generateSQLTokens();
return result;
}
1.1. Loading the SQL token generators
private void decorate(final Map<ShardingSphereRule, SQLRewriteContextDecorator> decorators, final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) {
decorators.forEach((key, value) -> value.decorate(key, props, sqlRewriteContext, routeContext));
}
The decorators themselves are loaded in the rewrite entry point (SQLRewriteEntry), which leaves an SPI extension point:
public SQLRewriteEntry(final ShardingSphereSchema schema, final ConfigurationProperties props, final Collection<ShardingSphereRule> rules) {
this.schema = schema;
this.props = props;
decorators = OrderedSPIRegistry.getRegisteredServices(SQLRewriteContextDecorator.class, rules);
}
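Three kinds of decorators are supported: ShardingSQLRewriteContextDecorator, EncryptSQLRewriteContextDecorator, and ShadowSQLRewriteContextDecorator.
Following the decorate code further, using the sharding decorator as the example: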
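OrderedSPIRegistry returns a map that pairs each configured rule with the decorator registered for that rule type, in a fixed order. The sketch below models that idea in plain Java; the Rule/Decorator types and their getOrder()/getTypeClass() methods are simplified assumptions, and the decorators are listed by hand instead of being discovered through SPI:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class OrderedSpiDemo {

    // Hypothetical rule types
    interface Rule { }
    static final class ShardingRuleStub implements Rule { }
    static final class EncryptRuleStub implements Rule { }
    static final class OtherRuleStub implements Rule { }

    // Hypothetical decorator contract: an order plus the rule type it handles
    interface Decorator {
        int getOrder();
        Class<? extends Rule> getTypeClass();
    }

    // Pairs each rule with the decorator for its type; iteration follows ascending getOrder()
    static Map<Rule, Decorator> register(List<Decorator> loaded, List<Rule> rules) {
        List<Decorator> sorted = new ArrayList<>(loaded);
        sorted.sort(Comparator.comparingInt(Decorator::getOrder));
        Map<Rule, Decorator> result = new LinkedHashMap<>();
        for (Decorator decorator : sorted) {
            for (Rule rule : rules) {
                if (decorator.getTypeClass() == rule.getClass()) {
                    result.put(rule, decorator);
                }
            }
        }
        return result;
    }

    static Decorator decorator(int order, Class<? extends Rule> type) {
        return new Decorator() {
            @Override
            public int getOrder() {
                return order;
            }

            @Override
            public Class<? extends Rule> getTypeClass() {
                return type;
            }
        };
    }

    public static void main(String[] args) {
        // Listed by hand here; a real registry would discover them via SPI
        List<Decorator> loaded = List.of(decorator(20, EncryptRuleStub.class), decorator(0, ShardingRuleStub.class));
        // OtherRuleStub has no decorator, so it is left out of the result
        List<Rule> rules = List.of(new EncryptRuleStub(), new ShardingRuleStub(), new OtherRuleStub());
        for (Map.Entry<Rule, Decorator> entry : register(loaded, rules).entrySet()) {
            System.out.println(entry.getKey().getClass().getSimpleName() + " -> order " + entry.getValue().getOrder());
        }
        // ShardingRuleStub -> order 0
        // EncryptRuleStub -> order 20
    }
}
```

Because the result is a LinkedHashMap, iterating it (as decorate does with forEach) applies the decorators in ascending order.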
public void decorate(final ShardingRule shardingRule, final ConfigurationProperties props, final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) {
if (routeContext.isFederated()) {
return;
}
for (ParameterRewriter each : new ShardingParameterRewriterBuilder(shardingRule, routeContext).getParameterRewriters(sqlRewriteContext.getSchema())) {
if (!sqlRewriteContext.getParameters().isEmpty() && each.isNeedRewrite(sqlRewriteContext.getSqlStatementContext())) {
// Rewrite the parameters
each.rewrite(sqlRewriteContext.getParameterBuilder(), sqlRewriteContext.getSqlStatementContext(), sqlRewriteContext.getParameters());
}
}
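/**
* 1. Build a ShardingTokenGenerateBuilder from the sharding rule and route context, and use it to obtain the SQL token generators
* 2. Add those generators to the SQL rewrite context
*/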
sqlRewriteContext.addSQLTokenGenerators(new ShardingTokenGenerateBuilder(shardingRule, routeContext).getSQLTokenGenerators());
}
public Collection<SQLTokenGenerator> getSQLTokenGenerators() {
// Create the SQL token generators and collect them
Collection<SQLTokenGenerator> result = buildSQLTokenGenerators();
for (SQLTokenGenerator each : result) {
// Inject the sharding rule
if (each instanceof ShardingRuleAware) {
((ShardingRuleAware) each).setShardingRule(shardingRule);
}
// Inject the route context
if (each instanceof RouteContextAware) {
((RouteContextAware) each).setRouteContext(routeContext);
}
}
return result;
}
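getSQLTokenGenerators() builds the generators first and only then hands each one the dependencies it declares through its *Aware interfaces. A minimal sketch of that setter-injection pattern; the interfaces are hypothetical, and plain strings stand in for ShardingRule and RouteContext:

```java
import java.util.ArrayList;
import java.util.List;

class AwareInjectionDemo {

    // Hypothetical marker-style interfaces mirroring ShardingRuleAware / RouteContextAware
    interface Generator { String name(); }
    interface RuleAware { void setRule(String rule); }
    interface RouteAware { void setRoute(String route); }

    // Needs only the rule
    static final class TableGen implements Generator, RuleAware {
        String rule;
        public String name() { return "table"; }
        public void setRule(String rule) { this.rule = rule; }
    }

    // Needs both the rule and the route
    static final class KeyGen implements Generator, RuleAware, RouteAware {
        String rule;
        String route;
        public String name() { return "key"; }
        public void setRule(String rule) { this.rule = rule; }
        public void setRoute(String route) { this.route = route; }
    }

    // Build first, then inject whatever each generator declares it needs
    static List<Generator> build(String rule, String route) {
        List<Generator> result = new ArrayList<>();
        result.add(new TableGen());
        result.add(new KeyGen());
        for (Generator each : result) {
            if (each instanceof RuleAware) {
                ((RuleAware) each).setRule(rule);
            }
            if (each instanceof RouteAware) {
                ((RouteAware) each).setRoute(route);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<Generator> gens = build("t_order rule", "ds_0.t_order_1");
        System.out.println(((TableGen) gens.get(0)).rule); // t_order rule
        System.out.println(((KeyGen) gens.get(1)).route);  // ds_0.t_order_1
    }
}
```

The builder never needs to know which concrete generator wants what; each generator opts in by implementing the matching interface.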
As you can see, there are many kinds of token generators:
private Collection<SQLTokenGenerator> buildSQLTokenGenerators() {
// All kinds of token generators
Collection<SQLTokenGenerator> result = new LinkedList<>();
// Filtering happens while adding: generators not needed for single-route SQL are skipped
addSQLTokenGenerator(result, new TableTokenGenerator());
addSQLTokenGenerator(result, new DistinctProjectionPrefixTokenGenerator());
addSQLTokenGenerator(result, new ProjectionsTokenGenerator());
addSQLTokenGenerator(result, new OrderByTokenGenerator());
addSQLTokenGenerator(result, new AggregationDistinctTokenGenerator());
addSQLTokenGenerator(result, new IndexTokenGenerator());
addSQLTokenGenerator(result, new ConstraintTokenGenerator());
addSQLTokenGenerator(result, new OffsetTokenGenerator());
addSQLTokenGenerator(result, new RowCountTokenGenerator());
addSQLTokenGenerator(result, new GeneratedKeyInsertColumnTokenGenerator());
addSQLTokenGenerator(result, new GeneratedKeyForUseDefaultInsertColumnsTokenGenerator());
addSQLTokenGenerator(result, new GeneratedKeyAssignmentTokenGenerator());
addSQLTokenGenerator(result, new ShardingInsertValuesTokenGenerator());
addSQLTokenGenerator(result, new GeneratedKeyInsertValuesTokenGenerator());
return result;
}
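The filtering mentioned in the comment happens inside addSQLTokenGenerator, whose body is not shown here. A minimal sketch, assuming the filter is driven by a marker interface (ShardingSphere's is named IgnoreForSingleRoute) combined with whether the route is single; the generator classes are hypothetical, and which real generators carry the marker is not covered:

```java
import java.util.ArrayList;
import java.util.List;

class SingleRouteFilterDemo {

    interface Generator { String name(); }

    // Marker: this generator is only useful when the SQL fans out to several routes
    interface IgnoreForSingleRoute { }

    static final class TableGen implements Generator {
        public String name() { return "table"; }
    }

    static final class RowCountGen implements Generator, IgnoreForSingleRoute {
        public String name() { return "rowCount"; }
    }

    // Skip marked generators when the route is single, otherwise add them
    static void add(List<Generator> result, Generator toBeAdded, boolean singleRouting) {
        if (toBeAdded instanceof IgnoreForSingleRoute && singleRouting) {
            return;
        }
        result.add(toBeAdded);
    }

    static List<String> build(boolean singleRouting) {
        List<Generator> result = new ArrayList<>();
        add(result, new TableGen(), singleRouting);
        add(result, new RowCountGen(), singleRouting);
        List<String> names = new ArrayList<>();
        for (Generator each : result) {
            names.add(each.name());
        }
        return names;
    }

    public static void main(String[] args) {
        System.out.println(build(true));  // [table]
        System.out.println(build(false)); // [table, rowCount]
    }
}
```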
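That's more than we can work through in one go, so this time we'll analyze only TableToken.
1.2. Generating the SQL tokens
Next comes the stage where the generators produce the SQL tokens. The context's generateSQLTokens() call ends up in the method below, which loops over every registered generator:
// Generate the SQL tokens and store them in the context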
result.generateSQLTokens();
public List<SQLToken> generateSQLTokens(final SQLStatementContext sqlStatementContext, final List<Object> parameters, final ShardingSphereSchema schema) {
List<SQLToken> result = new LinkedList<>();
for (SQLTokenGenerator each : sqlTokenGenerators) {
// Set up the generator: parameters, schema
setUpSQLTokenGenerator(each, parameters, schema, result);
if (!each.isGenerateSQLToken(sqlStatementContext)) {
continue;
}
// Generate the SQL token(s)
if (each instanceof OptionalSQLTokenGenerator) {
SQLToken sqlToken = ((OptionalSQLTokenGenerator) each).generateSQLToken(sqlStatementContext);
if (!result.contains(sqlToken)) {
result.add(sqlToken);
}
} else if (each instanceof CollectionSQLTokenGenerator) {
result.addAll(((CollectionSQLTokenGenerator) each).generateSQLTokens(sqlStatementContext));
}
}
return result;
}
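The loop treats the two kinds of generators differently: an OptionalSQLTokenGenerator yields a single token, which is skipped if an equal token is already in the list, while a CollectionSQLTokenGenerator's tokens are all added without that check. A simplified sketch of the dispatch, using strings as tokens and hypothetical interface names:

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;

class GenerateTokensDemo {

    // Hypothetical generator hierarchy; tokens are plain strings here
    interface Generator {
        boolean isGenerateSQLToken(String sql);
    }

    // Produces exactly one token
    interface OptionalGenerator extends Generator {
        String generateSQLToken(String sql);
    }

    // Produces any number of tokens
    interface CollectionGenerator extends Generator {
        Collection<String> generateSQLTokens(String sql);
    }

    static List<String> generate(List<Generator> generators, String sql) {
        List<String> result = new LinkedList<>();
        for (Generator each : generators) {
            if (!each.isGenerateSQLToken(sql)) {
                continue;
            }
            if (each instanceof OptionalGenerator) {
                String token = ((OptionalGenerator) each).generateSQLToken(sql);
                // Single tokens are de-duplicated
                if (!result.contains(token)) {
                    result.add(token);
                }
            } else if (each instanceof CollectionGenerator) {
                result.addAll(((CollectionGenerator) each).generateSQLTokens(sql));
            }
        }
        return result;
    }

    static OptionalGenerator optional(String token, boolean applies) {
        return new OptionalGenerator() {
            public boolean isGenerateSQLToken(String sql) { return applies; }
            public String generateSQLToken(String sql) { return token; }
        };
    }

    static CollectionGenerator collection(Collection<String> tokens) {
        return new CollectionGenerator() {
            public boolean isGenerateSQLToken(String sql) { return true; }
            public Collection<String> generateSQLTokens(String sql) { return tokens; }
        };
    }

    public static void main(String[] args) {
        List<Generator> generators = new ArrayList<>();
        generators.add(collection(List.of("table:t_order", "table:t_order_item")));
        generators.add(optional("offset", true));
        generators.add(optional("offset", true));    // equal single token, skipped
        generators.add(optional("rowCount", false)); // not applicable to this SQL
        System.out.println(generate(generators, "SELECT ..."));
        // [table:t_order, table:t_order_item, offset]
    }
}
```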
In generateSQLTokens, the logic differs from generator to generator. Here is the one for TableToken:
public Collection<TableToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof TableAvailable ? generateSQLTokens((TableAvailable) sqlStatementContext) : Collections.emptyList();
}
private Collection<TableToken> generateSQLTokens(final TableAvailable sqlStatementContext) {
Collection<TableToken> result = new LinkedList<>();
for (SimpleTableSegment each : sqlStatementContext.getAllTables()) {
if (shardingRule.findTableRule(each.getTableName().getIdentifier().getValue()).isPresent()) {
// Mark the table name with a SQL token
result.add(new TableToken(each.getTableName().getStartIndex(), each.getTableName().getStopIndex(), each, (SQLStatementContext) sqlStatementContext, shardingRule));
}
}
return result;
}
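As you can see, this step wraps every table that has a sharding rule in a TableToken, recording the start and stop positions of its name in the original SQL. The actual (rewritten) table name is not resolved yet; that happens later, per route unit, in TableToken.toString(RouteUnit).
2. Using the sqlTokens
Back to where the sqlTokens are consumed, the toSQL() method from the beginning of the article: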
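A minimal sketch of what the TableToken generator records, with hypothetical TableSegment/TableToken classes; a set of logic table names stands in for shardingRule.findTableRule(...), and a naive indexOf replaces the parser. Stop indexes are inclusive, which is consistent with the getStopIndex() + 1 used later in getStartIndex:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

class TableTokenDemo {

    // Hypothetical: a table name in the parsed SQL with inclusive start/stop offsets
    static final class TableSegment {
        final String name;
        final int startIndex;
        final int stopIndex;

        TableSegment(String name, int startIndex, int stopIndex) {
            this.name = name;
            this.startIndex = startIndex;
            this.stopIndex = stopIndex;
        }
    }

    // Hypothetical: records where the logic table name sits; nothing is rewritten yet
    static final class TableToken {
        final int startIndex;
        final int stopIndex;
        final String logicName;

        TableToken(int startIndex, int stopIndex, String logicName) {
            this.startIndex = startIndex;
            this.stopIndex = stopIndex;
            this.logicName = logicName;
        }

        @Override
        public String toString() {
            return logicName + "[" + startIndex + ".." + stopIndex + "]";
        }
    }

    // Demo helper, not a parser: locates the first occurrence of name in sql
    static TableSegment segment(String sql, String name) {
        int start = sql.indexOf(name);
        return new TableSegment(name, start, start + name.length() - 1);
    }

    // Emits a token only for tables that have a sharding rule
    static List<TableToken> generate(List<TableSegment> tables, Set<String> shardedTables) {
        List<TableToken> result = new ArrayList<>();
        for (TableSegment each : tables) {
            if (shardedTables.contains(each.name)) {
                result.add(new TableToken(each.startIndex, each.stopIndex, each.name));
            }
        }
        return result;
    }

    public static void main(String[] args) {
        String sql = "SELECT * FROM t_order o JOIN t_config c ON o.id = c.id";
        List<TableSegment> tables = List.of(segment(sql, "t_order"), segment(sql, "t_config"));
        // Only t_order is sharded, so t_config gets no token and stays untouched
        System.out.println(generate(tables, Set.of("t_order"))); // [t_order[14..20]]
    }
}
```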
public final String toSQL() {
if (context.getSqlTokens().isEmpty()) {
return context.getSql();
}
// Sort the sqlTokens by their position in the SQL
Collections.sort(context.getSqlTokens());
StringBuilder result = new StringBuilder();
// Append the original SQL that precedes the first token
result.append(context.getSql(), 0, context.getSqlTokens().get(0).getStartIndex());
for (SQLToken each : context.getSqlTokens()) {
// Append the token's rewritten text (for a TableToken, the actual table name)
result.append(each instanceof ComposableSQLToken ? getComposableSQLTokenText((ComposableSQLToken) each) : getSQLTokenText(each));
// Append the original SQL between this token and the next one (or the end of the SQL)
result.append(getConjunctionText(each));
}
return result.toString();
}
private String getComposableSQLTokenText(final ComposableSQLToken composableSQLToken) {
StringBuilder result = new StringBuilder();
for (SQLToken each : composableSQLToken.getSqlTokens()) {
result.append(getSQLTokenText(each));
result.append(getConjunctionText(each));
}
return result.toString();
}
protected String getSQLTokenText(final SQLToken sqlToken) {
if (sqlToken instanceof RouteUnitAware) {
return ((RouteUnitAware) sqlToken).toString(routeUnit);
}
return sqlToken.toString();
}
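The instanceof RouteUnitAware check is what lets one token produce a different text for each route unit, so the same logic SQL can be rendered once per target. A sketch with hypothetical classes, where a route unit is reduced to a logic-to-actual table map:

```java
import java.util.Map;

class RouteUnitAwareDemo {

    // Hypothetical route unit: the logic -> actual table mapping chosen for one target
    static final class RouteUnit {
        final Map<String, String> tables;

        RouteUnit(Map<String, String> tables) {
            this.tables = tables;
        }
    }

    interface Token { }

    // Tokens whose text depends on which route unit is being rendered
    interface RouteUnitAware {
        String toString(RouteUnit routeUnit);
    }

    // Text is the same for every route unit
    static final class FixedToken implements Token {
        final String text;

        FixedToken(String text) {
            this.text = text;
        }

        @Override
        public String toString() {
            return text;
        }
    }

    static final class TableToken implements Token, RouteUnitAware {
        final String logicName;

        TableToken(String logicName) {
            this.logicName = logicName;
        }

        @Override
        public String toString(RouteUnit routeUnit) {
            return routeUnit.tables.getOrDefault(logicName, logicName);
        }
    }

    // Same dispatch as getSQLTokenText
    static String tokenText(Token token, RouteUnit routeUnit) {
        if (token instanceof RouteUnitAware) {
            return ((RouteUnitAware) token).toString(routeUnit);
        }
        return token.toString();
    }

    public static void main(String[] args) {
        Token table = new TableToken("t_order");
        Token fixed = new FixedToken("LIMIT 20");
        RouteUnit unit0 = new RouteUnit(Map.of("t_order", "t_order_0"));
        RouteUnit unit1 = new RouteUnit(Map.of("t_order", "t_order_1"));
        System.out.println(tokenText(table, unit0)); // t_order_0
        System.out.println(tokenText(table, unit1)); // t_order_1
        System.out.println(tokenText(fixed, unit1)); // LIMIT 20
    }
}
```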
Looking only at TableToken, here is its toString(RouteUnit) implementation:
public String toString(final RouteUnit routeUnit) {
// First try to look up the actual table name in the route unit
String actualTableName = getLogicAndActualTables(routeUnit).get(tableName.getValue().toLowerCase());
// Not found: the table has no sharding mapping, so keep the (lowercased) logic name
actualTableName = null == actualTableName ? tableName.getValue().toLowerCase() : actualTableName;
return tableName.getQuoteCharacter().wrap(actualTableName);
}
private Map<String, String> getLogicAndActualTables(final RouteUnit routeUnit) {
Collection<String> tableNames = sqlStatementContext.getTablesContext().getTableNames();
Map<String, String> result = new HashMap<>(tableNames.size(), 1);
for (RouteMapper each : routeUnit.getTableMappers()) {
// From the route unit: logic table name -> actual table name
result.put(each.getLogicName().toLowerCase(), each.getActualName());
// If there are binding tables, add their mappings too
result.putAll(shardingRule.getLogicAndActualTablesFromBindingTable(routeUnit.getDataSourceMapper().getLogicName(), each.getLogicName(), each.getActualName(), tableNames));
}
return result;
}
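Resolving the actual name is a case-insensitive map lookup with a fallback to the lowercased logic name, after which the original quote characters are put back. A minimal sketch; TableMapper is a hypothetical stand-in for RouteMapper, the open/close delimiters stand in for QuoteCharacter.wrap, and the binding-table expansion is left out:

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class ActualTableNameDemo {

    // Hypothetical stand-in for RouteMapper: one logic -> actual table pair in a route unit
    static final class TableMapper {
        final String logicName;
        final String actualName;

        TableMapper(String logicName, String actualName) {
            this.logicName = logicName;
            this.actualName = actualName;
        }
    }

    // Keys are lowercased so the later lookup ignores case
    static Map<String, String> logicAndActualTables(List<TableMapper> mappers) {
        Map<String, String> result = new HashMap<>();
        for (TableMapper each : mappers) {
            result.put(each.logicName.toLowerCase(), each.actualName);
        }
        return result;
    }

    // Look up the actual name, fall back to the lowercased logic name, then re-apply the quotes
    static String render(String tableName, String open, String close, List<TableMapper> mappers) {
        String actual = logicAndActualTables(mappers).get(tableName.toLowerCase());
        actual = null == actual ? tableName.toLowerCase() : actual;
        return open + actual + close;
    }

    public static void main(String[] args) {
        List<TableMapper> routeUnit = List.of(new TableMapper("t_order", "t_order_1"));
        System.out.println(render("T_ORDER", "`", "`", routeUnit)); // `t_order_1`
        System.out.println(render("t_user", "", "", routeUnit));    // t_user
    }
}
```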
That resolves the table name. Now back to the getConjunctionText method called from toSQL():
private String getConjunctionText(final SQLToken sqlToken) {
return context.getSql().substring(getStartIndex(sqlToken), getStopIndex(sqlToken));
}
private int getStartIndex(final SQLToken sqlToken) {
int startIndex = sqlToken instanceof Substitutable ? ((Substitutable) sqlToken).getStopIndex() + 1 : sqlToken.getStartIndex();
return Math.min(startIndex, context.getSql().length());
}
private int getStopIndex(final SQLToken sqlToken) {
int currentSQLTokenIndex = context.getSqlTokens().indexOf(sqlToken);
return context.getSqlTokens().size() - 1 == currentSQLTokenIndex ? context.getSql().length() : context.getSqlTokens().get(currentSQLTokenIndex + 1).getStartIndex();
}
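Putting toSQL(), getConjunctionText(), getStartIndex() and getStopIndex() together: a substitutable token replaces the characters from its start index through its stop index, a non-substitutable token only inserts text at its start index, and the original SQL between tokens is copied over unchanged. A self-contained sketch with a hypothetical Token class:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class ToSqlDemo {

    // Hypothetical token: substitutable tokens replace sql[startIndex..stopIndex] (inclusive),
    // non-substitutable ones only insert text at startIndex
    static final class Token implements Comparable<Token> {
        final int startIndex;
        final int stopIndex;
        final String text;
        final boolean substitutable;

        Token(int startIndex, int stopIndex, String text, boolean substitutable) {
            this.startIndex = startIndex;
            this.stopIndex = stopIndex;
            this.text = text;
            this.substitutable = substitutable;
        }

        @Override
        public int compareTo(Token other) {
            return Integer.compare(startIndex, other.startIndex);
        }
    }

    static String toSQL(String sql, List<Token> tokens) {
        if (tokens.isEmpty()) {
            return sql;
        }
        List<Token> sorted = new ArrayList<>(tokens);
        Collections.sort(sorted);
        StringBuilder result = new StringBuilder();
        // Original SQL before the first token
        result.append(sql, 0, sorted.get(0).startIndex);
        for (int i = 0; i < sorted.size(); i++) {
            Token each = sorted.get(i);
            result.append(each.text);
            // Conjunction text: from just past this token (or its start, if it only inserts)
            // up to the next token's start, or to the end of the SQL for the last token
            int start = Math.min(each.substitutable ? each.stopIndex + 1 : each.startIndex, sql.length());
            int stop = i == sorted.size() - 1 ? sql.length() : sorted.get(i + 1).startIndex;
            result.append(sql, start, stop);
        }
        return result.toString();
    }

    public static void main(String[] args) {
        String sql = "SELECT * FROM t_order o JOIN t_order_item i ON o.order_id = i.order_id";
        List<Token> tokens = new ArrayList<>();
        // Added out of order on purpose; toSQL sorts them
        tokens.add(new Token(29, 40, "t_order_item_0", true));
        tokens.add(new Token(14, 20, "t_order_0", true));
        System.out.println(toSQL(sql, tokens));
        // SELECT * FROM t_order_0 o JOIN t_order_item_0 i ON o.order_id = i.order_id
    }
}
```

The sketch walks the sorted list by position instead of calling indexOf on each token as the original getStopIndex does; the result is the same as long as no two tokens in the list are equal.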