Spring 使用AbstractRoutingDataSource实现动态数据源

535 阅读 · 阅读约 2 分钟
  • 我的需求是:一个区域对应一个数据源,根据登录用户的区域进行数据源切换。


  • 创建默认数据源

<bean id="dataSource" class="com.alibaba.druid.pool.DruidDataSource" init-method="init" destroy-method="close">
  <property name="driverClassName" value="${jdbc.driver}"/>
  <property name="url" value="${jdbc.url}"/>
  <property name="username" value="${jdbc.username}"/>
  <property name="password" value="${jdbc.password}"/>
  <!-- further Druid pool settings (e.g. initialSize, maxActive) go here -->
</bean>
  • 自定义数据源路由类,继承【AbstractRoutingDataSource】
<bean id="multiRouteDataSource" class="com.*.MultiRouteDataSource">
   <!-- register the default data source under the lookup key "defaultTargetDataSource";
        MultiRouteDataSource returns this key when no area code is bound to the thread -->
   <property name="targetDataSources">
      <map>
         <entry key="defaultTargetDataSource" value-ref="dataSource"/>
      </map>
   </property>
   <!-- fallback target used by AbstractRoutingDataSource -->
   <property name="defaultTargetDataSource" ref="dataSource"/>
</bean>
  • 创建 sqlSessionFactory
<bean id="sqlSessionFactory" class="org.mybatis.spring.SqlSessionFactoryBean">
        <!-- use the routing data source so MyBatis follows the per-area switch -->
        <property name="dataSource" ref="multiRouteDataSource"/>
        <property name="configLocation" value="classpath:mybatis-config.xml"/>
        <property name="mapperLocations" value="classpath*:**/xml/**/*Mapper.xml"/>
</bean>
  • 创建 jdbcTemplate
<bean id="jdbcTemplate" class="org.springframework.jdbc.core.JdbcTemplate">
        <!-- use the routing data source so JdbcTemplate follows the per-area switch -->
        <property name="dataSource" ref="multiRouteDataSource"/>
</bean>
  • 多数据源切换实现类
import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.druid.stat.DruidDataSourceStatManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.jdbc.datasource.lookup.AbstractRoutingDataSource;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Routing data source that switches between per-area Druid pools.
 *
 * <p>The lookup key for each call is resolved from the area code bound to the current thread
 * via {@code AreaCodeHolder}. Area codes map to data-source names in {@code dbNameMapping}, and
 * those names are the keys registered in the target data-source map. When no area code is
 * bound, the entry registered under {@code "defaultTargetDataSource"} is used.
 */
public class MultiRouteDataSource extends AbstractRoutingDataSource {

    private static final Logger logger = LoggerFactory.getLogger(MultiRouteDataSource.class);

    /** Lookup key of the statically configured default data source (see the XML config). */
    private static final String DEFAULT_DATA_SOURCE_KEY = "defaultTargetDataSource";

    /**
     * The same map instance handed to the superclass. It is kept so pools can be added or
     * removed at runtime and the routing table rebuilt via {@code afterPropertiesSet()}.
     */
    private Map<Object, Object> targetDataSources;

    /**
     * key : areaCode
     * value : dataSourceName (a lookup key in targetDataSources)
     */
    private static final Map<String, String> dbNameMapping = new ConcurrentHashMap<>();

    /**
     * Returns the name of the data source the current call should use.
     *
     * @return {@code "defaultTargetDataSource"} when no area code is bound to the thread,
     *     otherwise the data-source name mapped to the bound area code
     * @throws IllegalStateException if an area code is bound but has no mapped data source
     */
    @Override
    protected Object determineCurrentLookupKey() {
        String areaCode = AreaCodeHolder.getAreaCode();
        String dataSourceName;
        if (areaCode == null) {
            dataSourceName = DEFAULT_DATA_SOURCE_KEY;
        } else {
            dataSourceName = getDataSourceName(areaCode);
        }
        logger.debug("根据areaCode:[{}]切换到数据源:[{}]", areaCode, dataSourceName);
        return dataSourceName;
    }

    /** Stores the target map both in the superclass and locally, for later runtime changes. */
    @Override
    public void setTargetDataSources(Map<Object, Object> targetDataSources) {
        super.setTargetDataSources(targetDataSources);
        this.targetDataSources = targetDataSources;
    }

    /**
     * Creates, initializes and registers one Druid pool per configuration entry.
     *
     * <p>If any entry fails to be created or initialized, every pool created by this call is
     * closed. In that case no new pool is left registered and no area mapping is added.
     *
     * @param dataSourceList one map per data source. Each map must contain {@code areaCode} and
     *     {@code dataSourceName}. Every other key is applied to the {@link DruidDataSource}
     *     property of the same name (e.g. {@code url}, {@code username}).
     * @return {@code true} if all pools were created and registered, {@code false} otherwise
     */
    protected boolean createDataSource(List<Map<String, Object>> dataSourceList) {
        List<DruidDataSource> created = new ArrayList<>();
        Map<Object, Object> newTargets = new LinkedHashMap<>();
        Map<String, String> newMappings = new HashMap<>();
        try {
            for (Map<String, Object> map : dataSourceList) {
                // Routing metadata: the area code and the lookup key of its pool.
                String areaCode = requiredString(map, "areaCode");
                String dataSourceName = requiredString(map, "dataSourceName");
                if (newTargets.containsKey(dataSourceName)) {
                    // A silent last-wins replacement would leak the earlier pool.
                    throw new IllegalArgumentException("duplicate dataSourceName: " + dataSourceName);
                }

                DruidDataSource druid = new DruidDataSource();
                // Track it before configuring, so it is closed if setup or init fails.
                created.add(druid);
                // DruidDataSource has no "dataSourceName" property, so BeanUtils would skip it.
                // Set the pool name explicitly so the pool is identifiable.
                druid.setName(dataSourceName);
                for (Map.Entry<String, Object> entry : map.entrySet()) {
                    String key = entry.getKey();
                    if ("areaCode".equals(key) || "dataSourceName".equals(key)) {
                        continue;
                    }
                    // Keys must match DruidDataSource property names (url, username, ...).
                    org.apache.commons.beanutils.BeanUtils.setProperty(druid, key, entry.getValue());
                }
                druid.init();
                newTargets.put(dataSourceName, druid);
                newMappings.put(areaCode, dataSourceName);
            }
            this.targetDataSources.putAll(newTargets);
            setTargetDataSources(this.targetDataSources);
            // Rebuild the resolved routing table from targetDataSources.
            super.afterPropertiesSet();
            // Publish area mappings only once the routing table can resolve their keys.
            dbNameMapping.putAll(newMappings);
            return true;
        } catch (Exception e) {
            // Roll back: unregister and close every pool this call created.
            if (this.targetDataSources != null) {
                for (Map.Entry<Object, Object> entry : newTargets.entrySet()) {
                    if (this.targetDataSources.get(entry.getKey()) == entry.getValue()) {
                        this.targetDataSources.remove(entry.getKey());
                    }
                }
            }
            for (DruidDataSource druid : created) {
                druid.close();
            }
            logger.error("createDataSource异常", e);
            return false;
        }
    }

    /**
     * Removes a runtime-created data source and shuts its pool down.
     *
     * <p>Currently unused. Area codes mapped to the removed data source are unmapped first, so
     * new requests for them fail fast instead of reaching a closing pool. The statically
     * configured default data source is never removed.
     *
     * @param datasourceid lookup key (data-source name) of the pool to remove
     * @return {@code true} if a Druid pool was registered under that key and has been removed
     */
    protected boolean deleteDatasource(String datasourceid) {
        if (datasourceid == null || DEFAULT_DATA_SOURCE_KEY.equals(datasourceid)) {
            return false;
        }
        Map<Object, Object> targetDataSources = this.targetDataSources;
        Object target = targetDataSources.get(datasourceid);
        if (!(target instanceof DruidDataSource)) {
            return false;
        }
        // Stop routing new requests to this pool before tearing it down.
        Iterator<Map.Entry<String, String>> it = dbNameMapping.entrySet().iterator();
        while (it.hasNext()) {
            if (datasourceid.equals(it.next().getValue())) {
                it.remove();
            }
        }
        targetDataSources.remove(datasourceid);
        setTargetDataSources(targetDataSources);
        super.afterPropertiesSet();
        // The original only unregistered the pool from DruidDataSourceStatManager, which left its
        // connections open. close() shuts the pool down. NOTE(review): close() is assumed to also
        // unregister it from DruidDataSourceStatManager; verify for the Druid version in use.
        ((DruidDataSource) target).close();
        return true;
    }

    /**
     * Looks up the data-source name mapped to an area code.
     *
     * @throws IllegalStateException if no data source is mapped to {@code areaCode}
     */
    private String getDataSourceName(String areaCode) {
        String dataSourceName = dbNameMapping.get(areaCode);
        if (dataSourceName == null) {
            throw new IllegalStateException("根据areaCode:" + areaCode + "无法查询到数据库连接!");
        }
        return dataSourceName;
    }

    /** Returns {@code map.get(key).toString()}, failing with a clear message when absent. */
    private static String requiredString(Map<String, Object> map, String key) {
        Object value = map.get(key);
        if (value == null) {
            throw new IllegalArgumentException("missing required data source property: " + key);
        }
        return value.toString();
    }

}
/**
 * Holds the area code of the user whose request is being handled by the current thread.
 *
 * <p>The value is stored in a {@link ThreadLocal}. Set it at the start of each request and call
 * {@link #clear()} when the request ends. Otherwise a pooled thread keeps the previous request's
 * area code, and the next request on that thread is routed to the wrong data source.
 */
public final class AreaCodeHolder {

    private static final ThreadLocal<String> local = new ThreadLocal<>();

    private AreaCodeHolder() {
        // static holder; not instantiable
    }

    /** Binds {@code var} as the current thread's area code. */
    public static void setAreaCode(String var) {
        local.set(var);
    }

    /** Returns the current thread's area code, or {@code null} if none is bound. */
    public static String getAreaCode() {
        return local.get();
    }

    /** Removes the current thread's area code; call this when the request finishes. */
    public static void clear() {
        local.remove();
    }

}
  • 配置监听器:在 Spring 容器刷新完成(ContextRefreshedEvent)后,调用 createDataSource 创建并注册数据源列表
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * 在IOC的容器的启动过程,当所有的bean都已经处理完成之后,spring ioc容器会有一个发布事件的动作
 */
@Component
public class DataSourceInitListener implements ApplicationListener<ContextRefreshedEvent> {

    private static final Logger logger = LoggerFactory.getLogger(DataSourceInitListener.class);

    @Override
    public void onApplicationEvent(ContextRefreshedEvent event) {
        long time = System.currentTimeMillis();
        boolean bool = false;
        try{
            logger.info("开始初始化数据源...");
            MultiRouteDataSource route =  event.getApplicationContext().getBean(MultiRouteDataSource.class);
            List <Map <String,Object>>  datasourcelist =  this.getDataSourceList();
            bool = route.createDataSource(datasourcelist);
        }catch (Exception e){
            logger.error("初始化数据源异常:{}",e);
        }
        logger.info("数据源初始化结果:{} 耗时:{} 毫秒",bool,(System.currentTimeMillis()-time));
    }
    
<!-- 模拟获取数据源,
        比如 :通过配置中心获取,
              调用远程接口获取,
              查询默认数据库获取  -->
    public List <Map <String,Object>> getDataSourceList(){

        List <Map <String,Object>> dataSourceList  = new ArrayList <>();
        {
            Map <String,Object> map = new HashMap <>();
            map.put("areaCode","0");
            map.put("dataSourceName","dataSource_1");
            map.put("username","username");
            map.put("password","password");
            map.put("url","jdbcUrl");
            map.put("driverClassName","driverClassName");
            dataSourceList.add(map);
        }
        {
            Map <String,Object> map = new HashMap <>();
            map.put("areaCode","1");
            map.put("dataSourceName","dataSource_2");
            map.put("username","username");
            map.put("password","password");
            map.put("url","jdbcUrl");
            map.put("driverClassName","driverClassName");
            dataSourceList.add(map);
        }
        return dataSourceList;
    }
}
  • 保存AreaCodeHolder信息

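  • 比如,自定义一个 Filter:每次请求进入时,从登录会话中取出用户的区域编码并保存到 AreaCodeHolder 中,请求处理结束后在 finally 中清除该值。AreaCodeHolder 基于 ThreadLocal,所以只在登录时设置一次是不够的;如果不清除,线程池复用线程时,后续请求会串用上一个用户的数据源。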

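  • 我的项目中用的是 Dubbo,通过 Dubbo 的隐式传参(attachment)把区域编码从消费端传递到服务端。因此我不需要使用【AreaCodeHolder】:Dubbo 的【RpcContext.getContext()】本身就保存在 ThreadLocal 中,只需在 determineCurrentLookupKey() 里改为通过 RpcContext.getContext().getAttachment("areaCode") 读取区域编码即可。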

消费者(调用端):
    RpcContext.getContext().setAttachment("areaCode", areaCode); // areaCode: the logged-in user's area code
提供者(服务端):
   RpcContext.getContext().getAttachment("areaCode" );