-
我的需求是:一个区域对应一个数据源,根据登录用户的区域进行数据源切换。
-
创建默认数据源
<!-- Default data source: a Druid pool configured from the jdbc.* placeholders. -->
<bean id="dataSource" class="com.alibaba.druid.pool.DruidDataSource" init-method="init" destroy-method="close">
    <property name="driverClassName" value="${jdbc.driver}"/>
    <property name="url" value="${jdbc.url}"/>
    <property name="username" value="${jdbc.username}"/>
    <property name="password" value="${jdbc.password}"/>
    <!-- Remaining pool properties are omitted here. -->
</bean>
- 自定义数据源路由类,继承【AbstractRoutingDataSource】
<bean id="multiRouteDataSource" class="com.*.MultiRouteDataSource">
    <!-- Register the default data source as the only initial routing target. -->
    <property name="targetDataSources">
        <map>
            <entry key="defaultTargetDataSource" value-ref="dataSource"/>
        </map>
    </property>
    <property name="defaultTargetDataSource" ref="dataSource"/>
</bean>
- 创建 sqlSessionFactory
<bean id="sqlSessionFactory" class="org.mybatis.spring.SqlSessionFactoryBean">
    <!-- Use the custom multiRouteDataSource instead of a fixed data source. -->
    <property name="dataSource" ref="multiRouteDataSource"/>
    <property name="configLocation" value="classpath:mybatis-config.xml"/>
    <property name="mapperLocations" value="classpath*:**/xml/**/*Mapper.xml"/>
</bean>
- 创建 jdbcTemplate
<bean id="jdbcTemplate" class="org.springframework.jdbc.core.JdbcTemplate">
    <!-- Use the custom multiRouteDataSource instead of a fixed data source. -->
    <property name="dataSource" ref="multiRouteDataSource"/>
</bean>
- 多数据源切换实现类
import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.druid.stat.DruidDataSourceStatManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.jdbc.datasource.lookup.AbstractRoutingDataSource;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Routing data source that picks the target database from the area code bound to the
 * current thread (see {@code AreaCodeHolder}).
 *
 * <p>Targets can be added and removed at runtime through {@link #createDataSource(List)} and
 * {@link #deleteDatasource(String)}. Neither method is synchronized, so callers must not
 * invoke them concurrently.
 */
public class MultiRouteDataSource extends AbstractRoutingDataSource {

    private static final Logger logger = LoggerFactory.getLogger(MultiRouteDataSource.class);

    /** Lookup key returned when no area code is bound to the current thread. */
    private static final String DEFAULT_LOOKUP_KEY = "defaultTargetDataSource";

    /** Config key carrying the area code; routing metadata, not a pool setting. */
    private static final String AREA_CODE_KEY = "areaCode";

    /** Config key carrying the data source name; routing metadata, not a pool setting. */
    private static final String DATA_SOURCE_NAME_KEY = "dataSourceName";

    /** The map most recently handed to {@link #setTargetDataSources(Map)}. */
    private Map<Object, Object> targetDataSources;

    /**
     * key : areaCode
     * value : dataSourceName
     */
    private static final Map<String, String> dbNameMapping = new ConcurrentHashMap<>();

    /**
     * Returns the name of the data source to use for the current thread.
     *
     * @return {@code "defaultTargetDataSource"} when no area code is bound, otherwise the
     *         data source name registered for the bound area code
     * @throws RuntimeException if an area code is bound but no data source is registered for it
     */
    @Override
    protected Object determineCurrentLookupKey() {
        String areaCode = AreaCodeHolder.getAreaCode();
        String dataSourceName = areaCode == null ? DEFAULT_LOOKUP_KEY : getDataSourceName(areaCode);
        logger.debug("根据areaCode:[{}]切换到数据源:[{}]", areaCode, dataSourceName);
        return dataSourceName;
    }

    /**
     * Passes the targets to the parent class and keeps a reference so that targets can be
     * added or removed later.
     *
     * @param targetDataSources lookup key to data source map
     */
    @Override
    public void setTargetDataSources(Map<Object, Object> targetDataSources) {
        super.setTargetDataSources(targetDataSources);
        this.targetDataSources = targetDataSources;
    }

    /**
     * Creates one Druid pool per entry and registers it as a routing target.
     *
     * <p>Each entry must contain {@code areaCode} and {@code dataSourceName}; every other key is
     * applied to the pool as a bean property, so it must match a {@code DruidDataSource} setter.
     * Registration is all-or-nothing: if any entry fails, the pools created by this call are
     * closed and the existing targets and area routes are left as they were.
     *
     * @param dataSourceList data source list containing the connection settings
     * @return {@code true} if every entry was registered, {@code false} otherwise
     */
    protected boolean createDataSource(List<Map<String, Object>> dataSourceList) {
        if (dataSourceList == null) {
            logger.error("createDataSource异常: dataSourceList为null");
            return false;
        }
        Map<Object, Object> previous = this.targetDataSources;
        List<DruidDataSource> created = new ArrayList<>();
        try {
            // Work on copies so that a failure half-way through leaves the live state untouched.
            Map<Object, Object> merged = new HashMap<>();
            if (previous != null) {
                merged.putAll(previous);
            }
            Map<String, String> newRoutes = new HashMap<>();
            for (Map<String, Object> config : dataSourceList) {
                String areaCode = requiredString(config, AREA_CODE_KEY);
                String dataSourceName = requiredString(config, DATA_SOURCE_NAME_KEY);
                DruidDataSource druid = new DruidDataSource();
                // Track before configuring so the pool is closed if anything below throws.
                created.add(druid);
                for (Map.Entry<String, Object> property : config.entrySet()) {
                    String key = property.getKey();
                    if (AREA_CODE_KEY.equals(key) || DATA_SOURCE_NAME_KEY.equals(key)) {
                        continue;
                    }
                    org.apache.commons.beanutils.BeanUtils.setProperty(druid, key, property.getValue());
                }
                druid.init();
                merged.put(dataSourceName, druid);
                newRoutes.put(areaCode, dataSourceName);
            }
            setTargetDataSources(merged);
            // Re-run the parent's initialisation so the new targets become resolved data sources.
            super.afterPropertiesSet();
            // Publish the area routes last, so a route never points at an unresolved pool.
            dbNameMapping.putAll(newRoutes);
            return true;
        } catch (Exception e) {
            logger.error("createDataSource异常", e);
            if (this.targetDataSources != previous) {
                setTargetDataSources(previous);
            }
            for (DruidDataSource druid : created) {
                druid.close();
            }
            return false;
        }
    }

    /**
     * Removes a routing target, drops every area route that points at it and closes its pool.
     *
     * <p>The entry registered under {@code "defaultTargetDataSource"} is never removed.
     *
     * @param datasourceid unique name of the data source, as used in {@code targetDataSources}
     * @return {@code true} if the data source was registered and has been removed
     */
    protected boolean deleteDatasource(String datasourceid) {
        Map<Object, Object> current = this.targetDataSources;
        if (datasourceid == null || current == null
                || DEFAULT_LOOKUP_KEY.equals(datasourceid)
                || !current.containsKey(datasourceid)) {
            return false;
        }
        // Drop the routes first so no new lookup is directed at the pool being closed.
        dbNameMapping.values().removeAll(Collections.singleton(datasourceid));
        Map<Object, Object> remaining = new HashMap<>(current);
        Object removed = remaining.remove(datasourceid);
        setTargetDataSources(remaining);
        super.afterPropertiesSet();
        if (removed instanceof DruidDataSource) {
            // NOTE(review): close() is expected to also unregister the pool from Druid's
            // stat manager - confirm against the Druid version in use.
            ((DruidDataSource) removed).close();
        }
        return true;
    }

    /**
     * Looks up the data source name registered for an area code.
     *
     * @param areaCode area code of the current user
     * @return the registered data source name
     * @throws RuntimeException if no data source is registered for the area code
     */
    private String getDataSourceName(String areaCode) {
        String dataSourceName = dbNameMapping.get(areaCode);
        if (dataSourceName == null) {
            throw new RuntimeException("根据areaCode:" + areaCode + "无法查询到数据库连接!");
        }
        return dataSourceName;
    }

    /** Returns the string form of a mandatory config entry, failing with a clear message if absent. */
    private static String requiredString(Map<String, Object> config, String key) {
        Object value = config.get(key);
        if (value == null) {
            throw new IllegalArgumentException("数据源配置缺少必填项:" + key);
        }
        return value.toString();
    }
}
/**
 * Holds the area code of the user whose request is being handled on the current thread.
 *
 * <p>The value lives in a {@link ThreadLocal}, so whoever sets it should call {@link #clear()}
 * when the request ends; otherwise the value stays bound to the thread.
 */
public final class AreaCodeHolder {

    private static final ThreadLocal<String> local = new ThreadLocal<>();

    /**
     * Binds an area code to the current thread.
     *
     * @param var the area code; {@code null} removes the binding
     */
    public static void setAreaCode(String var) {
        if (var == null) {
            local.remove();
            return;
        }
        local.set(var);
    }

    /**
     * Returns the area code bound to the current thread.
     *
     * @return the area code, or {@code null} if none is bound
     */
    public static String getAreaCode() {
        return local.get();
    }

    /** Removes the area code bound to the current thread. */
    public static void clear() {
        local.remove();
    }
}
- 配置监听,在spring加载完后,调用此方法实例化数据源列表
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* 在IOC的容器的启动过程,当所有的bean都已经处理完成之后,spring ioc容器会有一个发布事件的动作
*/
@Component
public class DataSourceInitListener implements ApplicationListener<ContextRefreshedEvent> {
private static final Logger logger = LoggerFactory.getLogger(DataSourceInitListener.class);
@Override
public void onApplicationEvent(ContextRefreshedEvent event) {
long time = System.currentTimeMillis();
boolean bool = false;
try{
logger.info("开始初始化数据源...");
MultiRouteDataSource route = event.getApplicationContext().getBean(MultiRouteDataSource.class);
List <Map <String,Object>> datasourcelist = this.getDataSourceList();
bool = route.createDataSource(datasourcelist);
}catch (Exception e){
logger.error("初始化数据源异常:{}",e);
}
logger.info("数据源初始化结果:{} 耗时:{} 毫秒",bool,(System.currentTimeMillis()-time));
}
<!-- 模拟获取数据源,
比如 :通过配置中心获取,
调用远程接口获取,
查询默认数据库获取 -->
public List <Map <String,Object>> getDataSourceList(){
List <Map <String,Object>> dataSourceList = new ArrayList <>();
{
Map <String,Object> map = new HashMap <>();
map.put("areaCode","0");
map.put("dataSourceName","dataSource_1");
map.put("username","username");
map.put("password","password");
map.put("url","jdbcUrl");
map.put("driverClassName","driverClassName");
dataSourceList.add(map);
}
{
Map <String,Object> map = new HashMap <>();
map.put("areaCode","1");
map.put("dataSourceName","dataSource_2");
map.put("username","username");
map.put("password","password");
map.put("url","jdbcUrl");
map.put("driverClassName","driverClassName");
dataSourceList.add(map);
}
return dataSourceList;
}
}
-
保存AreaCodeHolder信息
-
比如,自定义一个filter,在每次请求进入时,将当前登录用户的区域编码保存到AreaCodeHolder中,并在请求结束后清除(避免线程复用时串用区域编码)。
-
我项目中用的是dubbo,使用dubbo的隐式传参,将信息从消费端传递到服务端。所以我不需要用到【AreaCodeHolder】,因为dubbo的【RpcContext.getContext()】本身就是基于ThreadLocal保存的。
消费者:
RpcContext.getContext().setAttachment("areaCode", areaCode);
生产者:
RpcContext.getContext().getAttachment("areaCode" );