SpringBoot整合sharding-jdbc源码剖析

870 阅读15分钟

SpringBoot整合sharding-jdbc源码解析

前言

在上一篇文章中,我们留下了一个问题:sharding-jdbc是通过什么方式使得开发人员在日常的开发中无感的呢。这个问题的答案我们首先可以探讨一下Mybatis的实现原理(上一篇文章是以Mybatis为例,所以这篇文章我们也以Mybatis举例),从百度上可以得到一句言简意赅的话:MyBatis的基本工作原理是先封装SQL,接着调用JDBC操作数据库,最后把数据库返回的表结果封装成Java类(请点击此处查看出处

那么JDBC又是什么呢,其实指的就是我们最开始操作数据库的JDBC规范,流程如下:

JDBC规范流程.png

sharding-jdbc就是通过重写JDBC规范,对外提供与之前一样的接口达到对开发人员无感的。当然口说无凭,我们还是来通过源码来验证一下

1、SpringBoot整合sharding-jdbc之SpringBootConfiguration对象

SpringBootConfiguration对象是将sharding-jdbc整合进springboot的配置类,代码如下:

@Configuration
@ComponentScan({"org.apache.shardingsphere.spring.boot.converter"})
@EnableConfigurationProperties({SpringBootShardingRuleConfigurationProperties.class, SpringBootMasterSlaveRuleConfigurationProperties.class, SpringBootEncryptRuleConfigurationProperties.class, SpringBootPropertiesConfigurationProperties.class, SpringBootShadowRuleConfigurationProperties.class})
@ConditionalOnProperty(
   prefix = "spring.shardingsphere",
   name = {"enabled"},
   havingValue = "true",
   matchIfMissing = true
)
@AutoConfigureBefore({DataSourceAutoConfiguration.class})
public class SpringBootConfiguration implements EnvironmentAware {
   private final SpringBootShardingRuleConfigurationProperties shardingRule;
   private final SpringBootMasterSlaveRuleConfigurationProperties masterSlaveRule;
   private final SpringBootEncryptRuleConfigurationProperties encryptRule;
   private final SpringBootShadowRuleConfigurationProperties shadowRule;
   private final SpringBootPropertiesConfigurationProperties props;
   private final Map<String, DataSource> dataSourceMap = new LinkedHashMap();
   private final String jndiName = "jndi-name";
​
   @Bean
   @Conditional({ShardingRuleCondition.class})
   public DataSource shardingDataSource() throws SQLException {
     //调用ShardingDataSourceFactory创建DataSource对象
       return ShardingDataSourceFactory.createDataSource(this.dataSourceMap, (newShardingRuleConfigurationYamlSwapper()).swap(this.shardingRule), this.props.getProps());
  }
​
   @Bean
   @Conditional({MasterSlaveRuleCondition.class})
   public DataSource masterSlaveDataSource() throws SQLException {
     //调用MasterSlaveDataSourceFactory创建DataSource对象
       return MasterSlaveDataSourceFactory.createDataSource(this.dataSourceMap, (newMasterSlaveRuleConfigurationYamlSwapper()).swap(this.masterSlaveRule), this.props.getProps());
  }
​
   @Bean
   @Conditional({EncryptRuleCondition.class})
   public DataSource encryptDataSource() throws SQLException {
     //调用EncryptDataSourceFactory创建DataSource对象
       return EncryptDataSourceFactory.createDataSource((DataSource)this.dataSourceMap.values().iterator().next(), (newEncryptRuleConfigurationYamlSwapper()).swap(this.encryptRule), this.props.getProps());
  }
​
   @Bean
   @Conditional({ShadowRuleCondition.class})
   public DataSource shadowDataSource() throws SQLException {
     //调用ShadowDataSourceFactory创建DataSource对象
       return ShadowDataSourceFactory.createDataSource(this.dataSourceMap, (newShadowRuleConfigurationYamlSwapper()).swap(this.shadowRule), this.props.getProps());
  }
​
   @Bean
   public ShardingTransactionTypeScanner shardingTransactionTypeScanner() {
       return new ShardingTransactionTypeScanner();
  }
​
   public final void setEnvironment(Environment environment) {
     //实现EnvironmentAware接口,重写setEnvironment方法,从配置中获取配置的数据源名称
       String prefix = "spring.shardingsphere.datasource.";
       Iterator var3 = this.getDataSourceNames(environment, prefix).iterator();
​
       while(var3.hasNext()) {
           String each = (String)var3.next();
​
           try {
             //根据配置的数据源名称逐个获取数据源对象
               this.dataSourceMap.put(each, this.getDataSource(environment, prefix, each));
          } catch (ReflectiveOperationException var6) {
               throw new ShardingSphereException("Can't find datasource type!", var6);
          } catch (NamingException var7) {
               throw new ShardingSphereException("Can't find JNDI datasource!", var7);
          }
      }
​
  }
​
   private List<String> getDataSourceNames(Environment environment, String prefix) {
       StandardEnvironment standardEnv = (StandardEnvironment)environment;
       standardEnv.setIgnoreUnresolvableNestedPlaceholders(true);
       return null == standardEnv.getProperty(prefix + "name") ? (new InlineExpressionParser(standardEnv.getProperty(prefix +"names"))).splitAndEvaluate() : Collections.singletonList(standardEnv.getProperty(prefix + "name"));
  }
​
   private DataSource getDataSource(Environment environment, String prefix, String dataSourceName) throwsReflectiveOperationException, NamingException {
       Map<String, Object> dataSourceProps = (Map)PropertyUtil.handle(environment, prefix + dataSourceName.trim(), Map.class);
       Preconditions.checkState(!dataSourceProps.isEmpty(), "Wrong datasource properties!");
       if (dataSourceProps.containsKey("jndi-name")) {
           return this.getJndiDataSource(dataSourceProps.get("jndi-name").toString());
      } else {
         //创建真实的数据源对象
           DataSource result = DataSourceUtil.getDataSource(dataSourceProps.get("type").toString(), dataSourceProps);
           DataSourcePropertiesSetterHolder.getDataSourcePropertiesSetterByType(dataSourceProps.get("type").toString()).ifPresent((dataSourcePropertiesSetter) -> {
               dataSourcePropertiesSetter.propertiesSet(environment, prefix, dataSourceName, result);
          });
           return result;
      }
  }
​
   private DataSource getJndiDataSource(String jndiName) throws NamingException {
       JndiObjectFactoryBean bean = new JndiObjectFactoryBean();
       bean.setResourceRef(true);
       bean.setJndiName(jndiName);
       bean.setProxyInterface(DataSource.class);
       bean.afterPropertiesSet();
       return (DataSource)bean.getObject();
  }
​
   @Generated
   public SpringBootConfiguration(SpringBootShardingRuleConfigurationProperties shardingRule, SpringBootMasterSlaveRuleConfigurationProperties masterSlaveRule, SpringBootEncryptRuleConfigurationProperties encryptRule, SpringBootShadowRuleConfigurationProperties shadowRule, SpringBootPropertiesConfigurationProperties props) {
       this.shardingRule = shardingRule;
       this.masterSlaveRule = masterSlaveRule;
       this.encryptRule = encryptRule;
       this.shadowRule = shadowRule;
       this.props = props;
  }
}
  1. 首先可以看到该类实现了EnvironmentAware接口,该接口是Spring提供的扩展接口,通过setEnvironment方法从配置中获取配置到的数据源,然后将对应的数据源放入到datasourceMap中
  2. getDataSource的作用是根据我们配置的数据源名称去获取真实的数据源
  3. 然后在该类中有几个@Bean的方法,作用就是根据不同的配置调用不同的工厂类返回不同的sharding-jdbc自己封装的DataSource对象(例如数据分片返回的就是ShardingDataSource对象,读写分离就是返回MasterSlaveDataSource对象),接下来我们以数据分片为例继续深入源码解析。

2、SpringBoot整合sharding-jdbc之ShardingDatasource对象

接下来我们以开源的4.1.1版本来进行讲解

@Getter
public class ShardingDataSource extends AbstractDataSourceAdapter {
   
   private final ShardingRuntimeContext runtimeContext;
   
   static {
     //SPI机制
       NewInstanceServiceLoader.register(RouteDecorator.class);
       NewInstanceServiceLoader.register(SQLRewriteContextDecorator.class);
       NewInstanceServiceLoader.register(ResultProcessEngine.class);
  }
   
   public ShardingDataSource(final Map<String, DataSource> dataSourceMap, final ShardingRule shardingRule, final Properties props) throws SQLException {
       super(dataSourceMap);
       checkDataSourceType(dataSourceMap);
       runtimeContext = new ShardingRuntimeContext(dataSourceMap, shardingRule, props, getDatabaseType());
  }
   
   private void checkDataSourceType(final Map<String, DataSource> dataSourceMap) {
       for (DataSource each : dataSourceMap.values()) {
           Preconditions.checkArgument(!(each instanceof MasterSlaveDataSource), "Initialized data sources can not be master-slave data sources.");
      }
  }
   
   @Override
   public final ShardingConnection getConnection() {
       return new ShardingConnection(getDataSourceMap(), runtimeContext, TransactionTypeHolder.get());
  }
}
  1. 首先我们查看一下ShardingDataSource的类继承结构图:

ShardingDataSource类继承结构.png

可以看到ShardingDataSource是实现了DataSource接口的,那么它就具备JDBC规范中的相应接口

2. 首先我们看一下该类中的静态代码块,使用了SPI机制进行注册,我们点击进入具体的方法

public static <T> void register(final Class<T> service) {
//通过ServiceLoader中的load方法将实现了service接口的类进行加载
   for (T each : ServiceLoader.load(service)) {
       registerServiceClass(service, each);
  }
}
​
@SuppressWarnings("unchecked")
private static <T> void registerServiceClass(final Class<T> service, final T instance) {
   Collection<Class<?>> serviceClasses = SERVICE_MAP.get(service);
//判断当前SERVICE_MAP中是否存在,如果不存在则进行初始化
   if (null == serviceClasses) {
       serviceClasses = new LinkedHashSet<>();
  }
//添加进list中
   serviceClasses.add(instance.getClass());
   SERVICE_MAP.put(service, serviceClasses);
}

查看上面的代码我们可以确认SERVICE_MAP的key是接口,value是存放了该接口实现类的hashSet。

我们可以看见静态方法利用SPI机制加载了RouteDecorator、SQLRewriteContextDecorator、ResultProcessEngine三个接口的实现类,我们可以分别查看一下对应接口的实现类

RouteDecorator:实现类有ShardingRouteDecorator、MasterSlaveRouteDecorator实现类

RouteDecorator实现类.png

SQLRewriteContextDecorator:ShardingSQLRewriteContextDecorator、ShadowSQLRewriteContextDecorator、EncryptSQLRewriteContextDecorator实现类

SQLRewriteContextDecorator实现类.png

ResultProcessEngine:ShardingResultMergeEngine、EncryptResultDecoratorEngine

ResultProcessEngine实现类.png

关于上述类的作用我们后面走到具体的流程再进行描述,接下来我们看一下ShardingDataSource的构造方法,可以看到主要做了三件事

//调用父类方法完成初始化
super(dataSourceMap);
//检查数据源类型是否符合分片数据库要求
checkDataSourceType(dataSourceMap);
//创建数据分片运行时上下文
runtimeContext = new ShardingRuntimeContext(dataSourceMap, shardingRule, props, getDatabaseType());

最后我们需要关注一下ShardingDataSource的getConnection方法:

@Override
public final ShardingConnection getConnection() {
//返回的对象时ShardingConnection对象
   return new ShardingConnection(getDataSourceMap(), runtimeContext, TransactionTypeHolder.get());
}

可以看到调用ShardingDataSource的getConnection对象最后返回的其实是由sharding-jdbc自己封装的ShardingConnection对象

3、SpringBoot整合sharding-jdbc之ShardingConnection对象

  1. 首先我们来看一下ShardingConnection对象的类继承图

    image-20220713232356208

    跟ShardingDataSource一样,ShardingConnection也是实现了Conenction接口,那么它也具有JDBC规范中Connection一样的接口了

  2. 首先我们看一下ShardingConnection的构造方法

    public ShardingConnection(final Map<String, DataSource> dataSourceMap, final ShardingRuntimeContextruntimeContext, final TransactionType transactionType) {
       this.dataSourceMap = dataSourceMap;
       this.runtimeContext = runtimeContext;
       this.transactionType = transactionType;
    //创建事务管理器
       shardingTransactionManager =runtimeContext.getShardingTransactionManagerEngine().getTransactionManager(transactionType);
    }
    

    可以看到主要是完成一些初始化的工作

  3. 接下来我们可以看一下创建Statement对象的方法

    @Override
    public PreparedStatement prepareStatement(final String sql) throws SQLException {
       return new ShardingPreparedStatement(this, sql);
    }
    ​
    @Override
    public PreparedStatement prepareStatement(final String sql, final int resultSetType, final int resultSetConcurrency) throws SQLException {
       return new ShardingPreparedStatement(this, sql, resultSetType, resultSetConcurrency);
    }
    ​
    @Override
    public PreparedStatement prepareStatement(final String sql, final int resultSetType, final int resultSetConcurrency, finalint resultSetHoldability) throws SQLException {
       return new ShardingPreparedStatement(this, sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }
    ​
    @Override
    public PreparedStatement prepareStatement(final String sql, final int autoGeneratedKeys) throws SQLException {
       return new ShardingPreparedStatement(this, sql, autoGeneratedKeys);
    }
    ​
    @Override
    public PreparedStatement prepareStatement(final String sql, final int[] columnIndexes) throws SQLException {
       return new ShardingPreparedStatement(this, sql, Statement.RETURN_GENERATED_KEYS);
    }
    ​
    @Override
    public PreparedStatement prepareStatement(final String sql, final String[] columnNames) throws SQLException {
       return new ShardingPreparedStatement(this, sql, Statement.RETURN_GENERATED_KEYS);
    }
    ​
    @Override
    public Statement createStatement() {
       return new ShardingStatement(this);
    }
    ​
    @Override
    public Statement createStatement(final int resultSetType, final int resultSetConcurrency) {
       return new ShardingStatement(this, resultSetType, resultSetConcurrency);
    }
    ​
    @Override
    public Statement createStatement(final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) {
       return new ShardingStatement(this, resultSetType, resultSetConcurrency, resultSetHoldability);
    }
    

    ShardingConnection创建的Statement对象其实也是被sharding-jdbc封装的对象

    4、SpringBoot整合sharding-jdbc之ShardingPreparedStatement对象

    1. 同样的,首先我们查看一下ShardingPreparedStatment对象的类继承图

ShardingPreparedStatement对象类继承图.png 跟ShardingDataSource一样 不再赘述

2.查看ShardingPreparedStatement对象的真实的构造方法

private ShardingPreparedStatement(final ShardingConnection connection, final String sql,
                                     final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability, finalboolean returnGeneratedKeys) throws SQLException {
       if (Strings.isNullOrEmpty(sql)) {
           throw new SQLException(SQLExceptionConstant.SQL_STRING_NULL_OR_EMPTY);
      }
       this.connection = connection;
       //此时的sql还是占位符形式
       this.sql = sql;
       ShardingRuntimeContext runtimeContext = connection.getRuntimeContext();
       parameterMetaData = new ShardingParameterMetaData(runtimeContext.getSqlParserEngine(), sql);
    //创建准备引擎 可以看到创建的类型是PreparedQueryPrepareEngine类型
       prepareEngine = new PreparedQueryPrepareEngine(runtimeContext.getRule().toRules(), runtimeContext.getProperties(), runtimeContext.getMetaData(), runtimeContext.getSqlParserEngine());
    //创建执行器
       preparedStatementExecutor = new PreparedStatementExecutor(resultSetType, resultSetConcurrency, resultSetHoldability, returnGeneratedKeys, connection);
       //创建批量执行器
    batchPreparedStatementExecutor = new BatchPreparedStatementExecutor(resultSetType, resultSetConcurrency, resultSetHoldability, returnGeneratedKeys, connection);
    }
    ```

在构造方法中主要完成引擎对象以及执行器对象的创建

3.接下来我们主要分析ShardingPreparedStatement对象的executeUpdate方法(对应Insert语句),

    ```
    public int executeUpdate() throws SQLException {
       try {
           //原有数据的清理
           clearPrevious();
           //准备工作:完成SQL解析、SQL路由以及SQL重写
           prepare();
           //初始化执行器:进行真实数据源连接的建立以及Statement对象的创建
           initPreparedStatementExecutor();
           //真正开始执行
           return preparedStatementExecutor.executeUpdate();
      } finally {
           clearBatch();
      }
    }
    ```

第一步我们先进行原有缓存对象的清理,第二步我们进行准备工作,完成SQL解析,SQL路由以及SQL重写,第三步进行执行器的初始化,在这里就会进行真实数据源连接的建立以及Statement对象的创建,第四步开始真正执行数据插入。

4.clearPrevious()方法

    ```
    //先调用preparedStatementExecutor.clear()方法,如下为真实的clear方法
    public void clear() throws SQLException {
       clearStatements();
       statements.clear();
       parameterSets.clear();
       connections.clear();
       resultSets.clear();
       inputGroups.clear();
    }
    ```

主要做的就是清除工作

5.prepare方法

    ```
    private void prepare() {
       //生成执行上下文 prepareEngine的实际类型是PreparedQueryPrepareEngine,可以在构造方法里面看见
       executionContext = prepareEngine.prepare(sql, getParameters());
       findGeneratedKey().ifPresent(generatedKey ->generatedValues.add(generatedKey.getGeneratedValues().getLast()));
    }
    ```

    我们将关注点放在prepareEngine.prepare(sql, getParameters())方法中

    ```
    public ExecutionContext prepare(final String sql, final List<Object> parameters) {
       List<Object> clonedParameters = cloneParameters(parameters);
       //创建路由上下文
       RouteContext routeContext = executeRoute(sql, clonedParameters);
       //创建执行上下文对象
       ExecutionContext result = new ExecutionContext(routeContext.getSqlStatementContext());
       //执行重写
       result.getExecutionUnits().addAll(executeRewrite(sql, clonedParameters, routeContext));
       if (properties.<Boolean>getValue(ConfigurationPropertyKey.SQL_SHOW)) {
           SQLLogger.logSQL(sql, properties.<Boolean>getValue(ConfigurationPropertyKey.SQL_SIMPLE), result.getSqlStatementContext(), result.getExecutionUnits());
      }
       return result;
    }
    ```

第一步是调用executeRoute方法创建RouteContext对象,第二步创建ExecutionContext对象,第三步调用executeRewrite方法完成sql重写。

首先针对executeRoute方法进行分析

    ```
    private RouteContext executeRoute(final String sql, final List<Object> clonedParameters) {
       //利用SPI机制,不再进行说明
       registerRouteDecorator();
    //针对route方法进行分析,最终会走到PreparedQueryPrepareEngine的route方法
    //可以看到有一个router变量,这个变量是在什么时候进行创建的呢,可以在ShardingPreparedStatement的关于PrepareEngine的初始化中找到答案
       return route(router, sql, clonedParameters);
    }
    ```

    PreparedQueryPrepareEngine的route方法如下

    ```
    @Override
    protected RouteContext route(final DataNodeRouter dataNodeRouter, final String sql, final List<Object>parameters) {
       //调用dataNodeRouter的route方法进行路由
       return dataNodeRouter.route(sql, parameters, true);
    }
    ```

点击route方法进入

    ```
    public RouteContext route(final String sql, final List<Object> parameters, final boolean useCache) {
       //类似于责任链实现,可以实现RoutingHook接口进行拦截
       routingHook.start(sql);
       try {
           RouteContext result = executeRoute(sql, parameters, useCache);
           routingHook.finishSuccess(result, metaData.getSchema());
           return result;
           // CHECKSTYLE:OFF
      } catch (final Exception ex) {
           // CHECKSTYLE:ON
           routingHook.finishFailure(ex);
           throw ex;
      }
    }
    ​
    private RouteContext executeRoute(final String sql, final List<Object> parameters, final boolean useCache) {
    //创建路由上下文,见下方方法注释
           RouteContext result = createRouteContext(sql, parameters, useCache);
           for (Entry<BaseRule, RouteDecorator> entry : decorators.entrySet()) {
             //重点方法,其实就是从实现了RouteDecorator接口的实现类中挑选合适的类进行处理,我们可以选择ShardingRouteDecorator来进行分析
               result = entry.getValue().decorate(result, metaData, entry.getKey(), properties);
          }
           return result;
      }
    ​
    private RouteContext createRouteContext(final String sql, final List<Object> parameters, final boolean useCache) {
           SQLStatement sqlStatement = parserEngine.parse(sql, useCache);
           try {
               //创建SQLStatementContext对象 根据返回的Statement对象创建不同的Statement对象
               SQLStatementContext sqlStatementContext =SQLStatementContextFactory.newInstance(metaData.getSchema(), sql, parameters, sqlStatement);
               //创建RouteContext对象,此时的RouteContext还未完全进行初始化
               return new RouteContext(sqlStatementContext, parameters, new RouteResult());
               // TODO should pass parameters for master-slave
          } catch (final IndexOutOfBoundsException ex) {
               return new RouteContext(new CommonSQLStatementContext(sqlStatement), parameters, newRouteResult());
          }
      }
    ​
    /**
        * SharingRouteDecorator的重点方法
        */
    @Override
       public RouteContext decorate(final RouteContext routeContext, final ShardingSphereMetaData metaData, finalShardingRule shardingRule, final ConfigurationProperties properties) {
           SQLStatementContext sqlStatementContext = routeContext.getSqlStatementContext();
           List<Object> parameters = routeContext.getParameters();
           //进行校验
           ShardingStatementValidatorFactory.newInstance(
                   sqlStatementContext.getSqlStatement()).ifPresent(validator -> validator.validate(shardingRule, sqlStatementContext.getSqlStatement(), parameters));
           ShardingConditions shardingConditions = getShardingConditions(parameters, sqlStatementContext, metaData.getSchema(), shardingRule);
           boolean needMergeShardingValues = isNeedMergeShardingValues(sqlStatementContext, shardingRule);
           if (sqlStatementContext.getSqlStatement() instanceof DMLStatement && needMergeShardingValues) {
               checkSubqueryShardingValues(sqlStatementContext, shardingRule, shardingConditions);
               mergeShardingConditions(shardingConditions);
          }
         //创建ShardingRouteEngine对象
           ShardingRouteEngine shardingRouteEngine = ShardingRouteEngineFactory.newInstance(shardingRule, metaData, sqlStatementContext, shardingConditions, properties);
           //根据分片策略进行路由
           RouteResult routeResult = shardingRouteEngine.route(shardingRule);
           if (needMergeShardingValues) {
               Preconditions.checkState(1 == routeResult.getRouteUnits().size(), "Must have one sharding with subquery.");
          }
           return new RouteContext(sqlStatementContext, parameters, routeResult);
      }
    ```

我们可以看到最后是调用ShardingRouteEngine的route方法来进行路由,ShardingRouteEngine其实是一个接口,有多种不同的实现类,到底是哪一种实现类其实是在工厂类创建时根据不同类型的Statement对象来进行创建的,不再深入展开了,感兴趣的可以自行研究。

具体的实现类其实是ShardingStandardRoutingEngine,跟踪对应的route方法

    @Override
    public RouteResult route(final ShardingRule shardingRule) {
       if (isDMLForModify(sqlStatementContext) && 1 != ((TableAvailable) sqlStatementContext).getAllTables().size()) {
           throw new ShardingSphereException("Cannot support Multiple-Table for '%s'.", sqlStatementContext.getSqlStatement());
      }
    //调用generateResult方法,可以发现其中又调用了getDataNodes方法
       return generateRouteResult(getDataNodes(shardingRule, shardingRule.getTableRule(logicTableName)));
    }
    

我们针对getDataNode方法来进行分析

    /**
    * 获取数据节点
    * @param shardingRule
    * @param tableRule
    * @return
    */
    private Collection<DataNode> getDataNodes(final ShardingRule shardingRule, final TableRule tableRule) {
       //判断类型 强制路由还是进行分片 还是复杂条件
       if (isRoutingByHint(shardingRule, tableRule)) {
         //强制路由 不做深入讨论 可自行研究
           return routeByHint(shardingRule, tableRule);
      }
       if (isRoutingByShardingConditions(shardingRule, tableRule)) {
       //以标准流程为例进行深入
           return routeByShardingConditions(shardingRule, tableRule);
      }
       return routeByMixedConditions(shardingRule, tableRule);
    }
    ​
    private Collection<DataNode> routeByShardingConditions(final ShardingRule shardingRule, final TableRuletableRule) {
           //当shardingConditions不为空时,执行routeByShardingConditionsWithCondition
           return shardingConditions.getConditions().isEmpty()
                   ? route0(shardingRule, tableRule, Collections.emptyList(), Collections.emptyList()) : routeByShardingConditionsWithCondition(shardingRule, tableRule);
      }
    

routeByShardingConditions方法,需要进行重点关注,可以看到根据不同的条件走不同的逻辑,当我们进行数据分片的时候其实是会产生ShardingCondition对象的,我们可以继续跟routeByShardingConditionsWithCondition方法

    /**
    * 根据ShardingCondition进行路由
    * @param shardingRule
    * @param tableRule
    * @return
    */
    private Collection<DataNode> routeByShardingConditionsWithCondition(final ShardingRule shardingRule, finalTableRule tableRule) {
       Collection<DataNode> result = new LinkedList<>();
       for (ShardingCondition each : shardingConditions.getConditions()) {
           Collection<DataNode> dataNodes = route0(shardingRule, tableRule, 
                   getShardingValuesFromShardingConditions(shardingRule, shardingRule.getDatabaseShardingStrategy(tableRule).getShardingColumns(), each),
                   getShardingValuesFromShardingConditions(shardingRule, shardingRule.getTableShardingStrategy(tableRule).getShardingColumns(), each));
           result.addAll(dataNodes);
           originalDataNodes.add(dataNodes);
      }
       return result;
    }
    ​
    private Collection<DataNode> route0(final ShardingRule shardingRule, final TableRule tableRule, finalList<RouteValue> databaseShardingValues, final List<RouteValue> tableShardingValues) {
           //获取数据库节点
           Collection<String> routedDataSources = routeDataSources(shardingRule, tableRule, databaseShardingValues);
           Collection<DataNode> result = new LinkedList<>();
           for (String each : routedDataSources) {
               //根据数据库获取表节点
               result.addAll(routeTables(shardingRule, tableRule, each, tableShardingValues));
          }
           return result;
      }
    

其实最后还是走到了route0方法来进行操作,可以看到route0方法主要有两个步骤:

  1. 根据数据库分片条件获取真实的数据源
  2. 再根据得到的真实数据源逐一遍历获取真实的表

我们以获取真实的表为例进行讲解

    private Collection<DataNoderouteTables(final ShardingRule shardingRule, final TableRule tableRule, final StringroutedDataSource, final List<RouteValue> tableShardingValues) {
       Collection<String> availableTargetTables = tableRule.getActualTableNames(routedDataSource);
    //其实是调用分片策略来进行操作,以行表达式分片策略为例
       Collection<String> routedTables = new LinkedHashSet<>(tableShardingValues.isEmpty() ? availableTargetTables
              : shardingRule.getTableShardingStrategy(tableRule).doSharding(availableTargetTables, tableShardingValues, this.properties));
       Preconditions.checkState(!routedTables.isEmpty(), "no table route info");
       Collection<DataNode> result = new LinkedList<>();
       for (String each : routedTables) {
           result.add(new DataNode(routedDataSource, each));
      }
       return result;
    }
    /**
        * 执行分片
        * @param availableTargetNames available data sources or tables's names,可用的数据库表,其实就是所有的逻辑表
        * @param shardingValues sharding values
        * @param properties ShardingSphere properties
        * @return
        */
    @Override
       public Collection<StringdoSharding(final Collection<String> availableTargetNames, finalCollection<RouteValue> shardingValues, final ConfigurationProperties properties) {
           RouteValue shardingValue = shardingValues.iterator().next();
           if (properties.<Boolean>getValue(ConfigurationPropertyKey.ALLOW_RANGE_QUERY_WITH_INLINE_SHARDING) &&shardingValue instanceof RangeRouteValue) {
               return availableTargetNames;
          }
           Preconditions.checkState(shardingValue instanceof ListRouteValue"Inline strategy cannot support this type sharding:" + shardingValue.toString());
           Collection<String> shardingResult = doSharding((ListRouteValue) shardingValue);
           Collection<String> result = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
           for (String each : shardingResult) {
               //根据存在的真实的数据节点进行判断
               if (availableTargetNames.contains(each)) {
                   result.add(each);
              }
          }
           return result;
      }
    

在上述代码中可以看到有一个DataNode的类,我们来看一下它的实现:

    @RequiredArgsConstructor
    @Getter
    @ToString
    public final class DataNode {
       
       private static final String DELIMITER = ".";
       
    //真实的数据源名称,就是datasourceMap中的key,此时的dataSourceName已经是经过分片规则后获取的需要访问的数据源所对应的数据源名称了
       private final String dataSourceName;
       
    //逻辑表名称
       private final String tableName;
       
       /**
        * Constructs a data node with well-formatted string.
        *
        * @param dataNode string of data node. use {@code .} to split data source name and table name.
        */
       public DataNode(final String dataNode) {
           if (!isValidDataNode(dataNode)) {
               throw new ShardingSphereConfigurationException("Invalid format for actual data nodes: '%s'", dataNode);
          }
           List<String> segments = Splitter.on(DELIMITER).splitToList(dataNode);
           dataSourceName = segments.get(0);
           tableName = segments.get(1);
      }
       
       private static boolean isValidDataNode(final String dataNodeStr) {
           return dataNodeStr.contains(DELIMITER) && 2 == Splitter.on(DELIMITER).splitToList(dataNodeStr).size();
      }
       
       @Override
       public boolean equals(final Object object) {
           if (this == object) {
               return true;
          }
           if (null == object || getClass() != object.getClass()) {
               return false;
          }
           DataNode dataNode = (DataNode) object;
           return Objects.equal(this.dataSourceName.toUpperCase(), dataNode.dataSourceName.toUpperCase())
               && Objects.equal(this.tableName.toUpperCase(), dataNode.tableName.toUpperCase());
      }
       
       @Override
       public int hashCode() {
           return Objects.hashCode(dataSourceName.toUpperCase(), tableName.toUpperCase());
      }
    }
    

至此其实我们已经完成了executeRoute方法的分析,接下来的步骤就是进行执行上下文的创建与SQL重写,这一块的内容就留给各位自行研究一下。

  1. initPreparedStatementExecutor方法,这个方法里面主要做的事情包含真实数据源连接的建立以及真实Statement对象的建立

    private void initPreparedStatementExecutor() throws SQLException {
       preparedStatementExecutor.init(executionContext);
       //为PreparedStatement对象参数赋实际的值
       setParametersForStatements();
       replayMethodForStatements();
    }
    

    我们重点关注的对象应该放在preparedStatementExecutor.init(executionContext) ;

    public void init(final ExecutionContext executionContext) throws SQLException {
       setSqlStatementContext(executionContext.getSqlStatementContext());
       getInputGroups().addAll(obtainExecuteGroups(executionContext.getExecutionUnits()));
       cacheStatements();
    }
    
    1. 第一步设置当前SqlStatement上下文
    2. 第二步可以看到是给一个内部的集合添加一些所谓的执行组
    3. 缓存部分对象

    我们针对第二步来进行分析点击进入obtainExecuteGroups方法

    private Collection<InputGroup<StatementExecuteUnit>> obtainExecuteGroups(final Collection<ExecutionUnit>executionUnits) throws SQLException {
    //可以看到他在这里又继续调用getSqlExecutePrepareTemplate().getExecuteUnitGroups方法来进行处理,再次进入
       return getSqlExecutePrepareTemplate().getExecuteUnitGroups(executionUnits, newSQLExecutePrepareCallback() {
    ​
           /**
            * 根据数据库名称获取真实的数据库连接
            * @param connectionMode connection mode
            * @param dataSourceName data source name
            * @param connectionSize connection size
            * @return
            * @throws SQLException
            */
           @Override
           public List<Connection> getConnections(final ConnectionMode connectionMode, final StringdataSourceName, final int connectionSize) throws SQLException {
             //这里的方法就是真实的获取数据源连接的地方,很重要,可以看到有一个参数是dataSourceName,我们猜测这个dataSourceName就是最开始提到的datasourceMap中的key,如果是的话,又是什么时候调用到这个方法呢,我们接着往下看
               return PreparedStatementExecutor.super.getConnection().getConnections(connectionMode, dataSourceName, connectionSize);
          }
    ​
           /**
            * 创建JDBC执行联合对象
            * @param connection connection
            * @param executionUnit execution unit
            * @param connectionMode connection mode
            * @return
            * @throws SQLException
            */
           @Override
           public StatementExecuteUnit createStatementExecuteUnit(final Connection connection, final ExecutionUnitexecutionUnit, final ConnectionMode connectionMode) throws SQLException {
               return new StatementExecuteUnit(executionUnit, createPreparedStatement(connection, executionUnit.getSqlUnit().getSql()), connectionMode);
          }
      });
    }
    ​
    private Collection<InputGroup<StatementExecuteUnit>> getSynchronizedExecuteUnitGroups(
               final Collection<ExecutionUnit> executionUnits, final SQLExecutePrepareCallback callback) throwsSQLException {
    //
           Map<String, List<SQLUnit>> sqlUnitGroups = getSQLUnitGroups(executionUnits);
           Collection<InputGroup<StatementExecuteUnit>> result = new LinkedList<>();
           for (Entry<String, List<SQLUnit>> entry : sqlUnitGroups.entrySet()) {
               result.addAll(getSQLExecuteGroups(entry.getKey(), entry.getValue(), callback));
          }
           return result;
      }
    

    可以看到它调用了getExecuteUnitGroups方法

    public Collection<InputGroup<StatementExecuteUnit>> getExecuteUnitGroups(final Collection<ExecutionUnit>executionUnits, final SQLExecutePrepareCallback callback) throws SQLException {
     //请对这里的callback对象着重关注,再次进入getSynchronizedExecuteUnitGroups方法
       return getSynchronizedExecuteUnitGroups(executionUnits, callback);
    }
    

    接下来我们再进入getSynchronizedExecuteUnitGroups方法中一探究竟

    private Collection<InputGroup<StatementExecuteUnit>> getSynchronizedExecuteUnitGroups(
           final Collection<ExecutionUnit> executionUnits, final SQLExecutePrepareCallback callback) throwsSQLException {
    //调用getSQLUnitGroups获取SQLUnit对象
       Map<StringList<SQLUnit>> sqlUnitGroups = getSQLUnitGroups(executionUnits);
       Collection<InputGroup<StatementExecuteUnit>> result = new LinkedList<>();
       for (Entry<StringList<SQLUnit>> entry : sqlUnitGroups.entrySet()) {
           result.addAll(getSQLExecuteGroups(entry.getKey(), entry.getValue(), callback));
      }
       return result;
    }
    

首先调用getSQLUnitGroups方法获取SQLUnit对象,类结构如下,其实就是sql+参数

@RequiredArgsConstructor
@Getter
@EqualsAndHashCode(of = { "sql" })
@ToString
public final class SQLUnit {
   
   private final String sql;
   
   private final List<Objectparameters;
}

接下来我们深入看一下getSQLUnitGroups方法

private Map<String, List<SQLUnit>> getSQLUnitGroups(final Collection<ExecutionUnit> executionUnits) {
   Map<String, List<SQLUnit>> result = new LinkedHashMap<>(executionUnits.size(), 1);
   for (ExecutionUnit each : executionUnits) {
       if (!result.containsKey(each.getDataSourceName())) {
           result.put(each.getDataSourceName(), new LinkedList<>());
      }
       result.get(each.getDataSourceName()).add(each.getSqlUnit());
  }
   return result;
}

其实很简单,就是生成一个map,key是数据源名称,value是一个存放SQLUnit的List,还记得executionUnits是什么吗(需要自行观看源码查找答案),给一个线索:从ExecutionContext上下功夫,寻找ExecutionContext的创建时机。

接下来我们对getSynchronizedExecuteUnitGroups方法中的

Collection<InputGroup<StatementExecuteUnit>> result = new LinkedList<>();
for (Entry<String, List<SQLUnit>> entry : sqlUnitGroups.entrySet()) {
//通过上述分析我们知道entry.getKey()其实是数据源名称,entry.getValue()其实是SQLUnit的集合
   result.addAll(getSQLExecuteGroups(entry.getKey(), entry.getValue(), callback));
}

我们再进入getSQLExecuteGroups查看对应逻辑

private List<InputGroup<StatementExecuteUnit>> getSQLExecuteGroups(final String dataSourceName,
                                                                  final List<SQLUnit> sqlUnits, final SQLExecutePrepareCallback callback) throwsSQLException {
   List<InputGroup<StatementExecuteUnit>> result = new LinkedList<>();
   int desiredPartitionSize = Math.max(0 == sqlUnits.size() % maxConnectionsSizePerQuery ? sqlUnits.size() /maxConnectionsSizePerQuery : sqlUnits.size() / maxConnectionsSizePerQuery + 11);
   List<List<SQLUnit>> sqlUnitPartitions = Lists.partition(sqlUnits, desiredPartitionSize);
   ConnectionMode connectionMode = maxConnectionsSizePerQuery < sqlUnits.size() ? ConnectionMode.CONNECTION_STRICTLY : ConnectionMode.MEMORY_STRICTLY;
   //调用callback的getConnections方法进行真正的连接的创建
   List<Connection> connections = callback.getConnections(connectionMode, dataSourceName, sqlUnitPartitions.size());
   int count = 0;
   for (List<SQLUnit> each : sqlUnitPartitions) {
       result.add(getSQLExecuteGroup(connectionMode, connections.get(count++), dataSourceName, each, callback));
  }
   return result;
}

在上述方法中我们可以看到调用了callback的getConnection方法进行真实数据源连接的创建,关于具体getConnections方法的实现还请各位同学自行研究

后面可以看到调用getSQLExecuteGroup方法做一些处理,我们再次深入

private InputGroup<StatementExecuteUnit> getSQLExecuteGroup(final ConnectionMode connectionMode, final Connection connection,
                                                           final String dataSourceName, final List<SQLUnit> sqlUnitGroup, final SQLExecutePrepareCallbackcallback) throws SQLException {
   List<StatementExecuteUnit> result = new LinkedList<>();
   for (SQLUnit each : sqlUnitGroup) {
       //调用callback的createStatementExecuteUnit方法进行创建
       result.add(callback.createStatementExecuteUnit(connection, new ExecutionUnit(dataSourceName, each), connectionMode));
  }
   return new InputGroup<>(result);
}

可以看到这里再次调用了callback中的createStatementExecuteUnit进行StatementExecuteUnit类的创建,我们看一下StatementExecuteUnit的结构

@RequiredArgsConstructor
@Getter
public final class StatementExecuteUnit {
   //数据源名称+sqlunit,可以参照下方结构
   private final ExecutionUnit executionUnit;
   
//操作数据库的Statement对象
   private final Statement statement;
   
   private final ConnectionMode connectionMode;
}
@RequiredArgsConstructor
@Getter
@EqualsAndHashCode
@ToString
public final class ExecutionUnit {
   
//真实的数据源名称
   private final String dataSourceName;
   
//sql+参数
   private final SQLUnit sqlUnit;
}

至此我们就完成initPreparedStatementExecutor方法的主要逻辑的梳理,接下来我们就来进行最后执行逻辑的梳理

7.SQL开始执行的真正方法:executeUpdate

/**
* Execute update.
* 
* @return effected records count
* @throws SQLException SQL exception
*/
public int executeUpdate() throws SQLException {
   final boolean isExceptionThrown = ExecutorExceptionHandler.isExceptionThrown();
   SQLExecuteCallback<Integer> executeCallback =SQLExecuteCallbackFactory.getPreparedUpdateSQLExecuteCallback(getDatabaseType(), isExceptionThrown);
   List<Integer> results = executeCallback(executeCallback);
   if (isAccumulate()) {
       return accumulate(results);
  } else {
       return results.get(0);
  }
}

主要逻辑就两步,调用SQLExecuteCallbackFactory生成executeCallback对象,第二步就执行executeCallback方法

protected final <T> List<T> executeCallback(final SQLExecuteCallback<T> executeCallback) throws SQLException {
   List<T> result = sqlExecuteTemplate.execute((Collection) inputGroups, executeCallback);
   refreshMetaDataIfNeeded(connection.getRuntimeContext(), sqlStatementContext);
   return result;
}

请注意这里有个inputGroups参数,就是刚才在init方法中生成的,再继续深入会进入到ExecutorEngine的execute方法,中间有一些比较简单的逻辑直接略过了

public <I, O> List<O> execute(final Collection<InputGroup<I>> inputGroups, 
                            final GroupedCallback<I, O> firstCallback, final GroupedCallback<I, O> callback, final boolean serial) throws SQLException {
  if (inputGroups.isEmpty()) {
      return Collections.emptyList();
  }
  //并行执行还是串行执行 --> 当sql需要进行跨库跨表查询时将进行并行执行最后再归并结果进行排序
  return serial ? serialExecute(inputGroups, firstCallback, callback) : parallelExecute(inputGroups, firstCallback, callback);
}

可以看到在这个方法中存在是并行执行还是串行执行逻辑,为什么会存在两种不同的执行方式请各位自行思考。我们以searialExecute为例

private <IO> List<O> serialExecute(final Collection<InputGroup<I>> inputGroups, final GroupedCallback<IO> firstCallback, finalGroupedCallback<IO> callback) throws SQLException {
   Iterator<InputGroup<I>> inputGroupsIterator = inputGroups.iterator();
   InputGroup<I> firstInputs = inputGroupsIterator.next();
   //调用syncExecute方法进行执行
   List<O> result = new LinkedList<>(syncExecute(firstInputs, null == firstCallback ? callback : firstCallback));
   for (InputGroup<I> each : Lists.newArrayList(inputGroupsIterator)) {
       result.addAll(syncExecute(each, callback));
  }
   return result;
}

可以看到就是遍历我们传进来的inputgroup调用syncExecute方法进行执行

private <IO> Collection<O> syncExecute(final InputGroup<I> inputGroup, final GroupedCallback<IO> callback) throwsSQLException {
   //可以看到最后执行的方法其实是callback的execute方法
   return callback.execute(inputGroup.getInputs(), trueExecutorDataMap.getValue());
}

最终其实是调用callback的execute方法进行执行,真实的逻辑会走到SQLExecuteCallBack的execute方法:

@Override
public final Collection<T> execute(final Collection<StatementExecuteUnit> statementExecuteUnits, 
                                  final boolean isTrunkThread, final Map<String, Object> dataMap) throws SQLException {
   Collection<T> result = new LinkedList<>();
   for (StatementExecuteUnit each : statementExecuteUnits) {
       result.add(execute0(each, isTrunkThread, dataMap));
  }
   return result;
}

可以看到我们前面提到的StatementExecuteUnit起到了作用,点击进入

private T execute0(final StatementExecuteUnit statementExecuteUnit, final boolean isTrunkThread, final Map<String, Object> dataMap) throws SQLException {
   ExecutorExceptionHandler.setExceptionThrown(isExceptionThrown);
   DataSourceMetaData dataSourceMetaData =getDataSourceMetaData(statementExecuteUnit.getStatement().getConnection().getMetaData());
   SQLExecutionHook sqlExecutionHook = new SPISQLExecutionHook();
   try {
     //获取ExecutionUnit对象
       ExecutionUnit executionUnit = statementExecuteUnit.getExecutionUnit();
       sqlExecutionHook.start(executionUnit.getDataSourceName(), executionUnit.getSqlUnit().getSql(), executionUnit.getSqlUnit().getParameters(), dataSourceMetaData, isTrunkThread, dataMap);
       T result = executeSQL(executionUnit.getSqlUnit().getSql(), statementExecuteUnit.getStatement(), statementExecuteUnit.getConnectionMode());
       sqlExecutionHook.finishSuccess();
       return result;
  } catch (final SQLException ex) {
       sqlExecutionHook.finishFailure(ex);
       ExecutorExceptionHandler.handleException(ex);
       return null;
  }
}

在这个方法中有一个重要的方法就是executeSQL方法,可以看到入参有sql语句,真实的Statement对象以及连接模式(作用可忽略)接下来就是真实的调用Statement对象进行SQL执行的步骤了。

关于后续的结果集处理这里就不再进行深入展开了。

至此咱们关于sharding-jdbc的整体sql的执行流程就结束了。

写在最后

  1. 这篇文章主要的目的是针对SprinBoot整合sharding-jdbc的原理进行了介绍,篇幅较长,但是其中还是有很多点没有讲到,各位如果感兴趣可以继续深入研究
  2. 学无止境,望各位与我共同进步