Spring Boot整合JPA实现多数据源切换(Oracle、Mysql)

626 次阅读 · 约 2 分钟读完

场景

我们希望在同一个应用中,通过代码查询不同数据库(如 Oracle、MySQL)中不同表的数据。

实现方案

  • 步骤
  1. 核心是继承 Spring 内置的 AbstractRoutingDataSource 抽象类,并实现其 determineCurrentLookupKey() 方法;
  2. 实例化多个 DataSource,并将它们放入该抽象类的 targetDataSources 中;
  3. 在操作数据前,DAO 层会先调用 AbstractRoutingDataSource 的 getConnection(),该方法内部会调用 determineCurrentLookupKey() 来确定使用哪个 DataSource;
  4. 借助 ThreadLocal 保存当前线程的数据源标识,便于在整个业务流程中获取。
  • 难点

不同连接池在实例化事务管理器以及创建 SessionFactory(在 JPA 中即 EntityManagerFactory)的方式上各有不同,调试起来可能会花些时间。

代码实现

1、环境介绍

  • springboot版本: 1.5.6.RELEASE
  • 连接池:alibaba druid

2、配置文件

spring:
  main:
    # NOTE(review): this property exists only in Spring Boot 2.1+; the article targets 1.5.6 — confirm the Boot version.
    allow-bean-definition-overriding: true
  datasource: # multiple data source configuration
    # Default database (bound by DataSourceConfiguration via prefix spring.datasource.default)
    default:
      driver-class-name: com.mysql.cj.jdbc.Driver
      url: jdbc:mysql://localhost:3306/mysql1?useUnicode=true&characterEncoding=UTF-8&useJDBCCompliantTimezoneShift=true&useLegacyDatetimeCode=false&serverTimezone=Asia/Shanghai
      username: root
      password: 123456
    # Oracle database (bound via prefix spring.datasource.oracle)
    oracle:
      driver-class-name: oracle.jdbc.driver.OracleDriver
      url: jdbc:oracle:thin:@192.168.1.188:1521/csta
      username: root
      password: root
    # Secondary MySQL database (bound via prefix spring.datasource.mysql)
    mysql:
      driver-class-name: com.mysql.cj.jdbc.Driver
      url: jdbc:mysql://localhost:3306/mysql2?useUnicode=true&characterEncoding=UTF-8&useJDBCCompliantTimezoneShift=true&useLegacyDatetimeCode=false&serverTimezone=Asia/Shanghai
      username: root
      password: 123456
    # Pool implementation class; read by DataSourceConfiguration through @Value("${spring.datasource.type}")
    type: com.alibaba.druid.pool.DruidDataSource
    # Druid connection pool settings.
    # NOTE(review): the data source beans in DataSourceConfiguration bind only the
    # default/oracle/mysql prefixes, so these pool settings are presumably NOT applied
    # to them unless something else binds spring.datasource.druid — confirm.
    druid:
      filters: stat
      initialSize: 5
      maxActive: 20
      maxPoolPreparedStatementPerConnectionSize: 20
      maxWait: 60000
      minEvictableIdleTimeMillis: 30000
      minIdle: 5
      poolPreparedStatements: false
      testOnBorrow: false
      testOnReturn: false
      testWhileIdle: true
      timeBetweenEvictionRunsMillis: 60000
      # Oracle variant
      validation-query: SELECT 1 FROM DUAL # SQL used to check whether a connection is still valid
      # MySQL variant
      # NOTE(review): "validation-queryM" does not appear to map to any pool property; presumably kept as a note only — confirm.
      validation-queryM: SELECT 1 # SQL used to check whether a connection is still valid
  jpa:
    hibernate:
      ddl-auto: none
      # NOTE(review): oracle-dialect / mysql-dialect are not read anywhere in this code;
      # JpaEntityManager sets no dialect explicitly — confirm how the dialect is chosen.
      oracle-dialect: org.hibernate.dialect.Oracle10gDialect
      mysql-dialect: org.hibernate.dialect.MySQL8Dialect
    show-sql: false

3、代码实现如下

  • 创建类 DataSourceConfiguration
/**
 * Declares the physical data sources (default MySQL, secondary MySQL, Oracle) and a routing
 * data source that picks one of them per call.
 *
 * <p>Each physical data source is bound from its own prefix: {@code spring.datasource.default},
 * {@code spring.datasource.mysql}, or {@code spring.datasource.oracle}. Its implementation
 * class is taken from {@code spring.datasource.type}.
 *
 * <p>The bean names below double as routing keys. They must match the strings stored in
 * {@code DataSourceContextHolder}.
 */
@Configuration
@Slf4j
public class DataSourceConfiguration {

    /** Bean name and routing key of the default (MySQL) data source. */
    private static final String DEFAULT_DATA_SOURCE = "defaultDataSource";
    /** Bean name and routing key of the secondary MySQL data source. */
    private static final String MYSQL_DATA_SOURCE = "mysqlDataSource";
    /** Bean name and routing key of the Oracle data source. */
    private static final String ORACLE_DATA_SOURCE = "oracleDataSource";

    /** Pool implementation class, read from {@code spring.datasource.type}. */
    @Value("${spring.datasource.type}")
    private Class<? extends DataSource> dataSourceType;

    /**
     * Creates the default MySQL data source.
     *
     * <p>It is {@code @Primary}, so any bean that injects a plain {@link DataSource} without a
     * qualifier receives this one, not the routing data source.
     */
    @Primary
    @Bean(value = DEFAULT_DATA_SOURCE)
    @Qualifier(DEFAULT_DATA_SOURCE)
    @ConfigurationProperties(prefix = "spring.datasource.default")
    public DataSource defaultDataSource() {
        log.info("create default(mysql) datasource...");
        return DataSourceBuilder.create().type(dataSourceType).build();
    }

    /** Creates the secondary MySQL data source bound from {@code spring.datasource.mysql}. */
    @Bean(value = MYSQL_DATA_SOURCE)
    @Qualifier(MYSQL_DATA_SOURCE)
    @ConfigurationProperties(prefix = "spring.datasource.mysql")
    public DataSource mysqlDataSource() {
        log.info("create mysql datasource...");
        return DataSourceBuilder.create().type(dataSourceType).build();
    }

    /** Creates the Oracle data source bound from {@code spring.datasource.oracle}. */
    @Bean(value = ORACLE_DATA_SOURCE)
    @Qualifier(ORACLE_DATA_SOURCE)
    @ConfigurationProperties(prefix = "spring.datasource.oracle")
    public DataSource oracleDataSource() {
        log.info("create oracle datasource...");
        return DataSourceBuilder.create().type(dataSourceType).build();
    }

    /**
     * Builds the routing data source that delegates to one of the three physical data sources.
     *
     * <p>The delegate is chosen by the key returned from
     * {@code DynamicDataSourceRouter#determineCurrentLookupKey()}.
     *
     * @param defaultDataSource default MySQL data source; also used as the fallback target
     * @param mysqlDataSource   secondary MySQL data source
     * @param oracleDataSource  Oracle data source
     * @return the configured router
     */
    @Bean(name = "routingDataSource")
    public AbstractRoutingDataSource routingDataSource(@Qualifier(DEFAULT_DATA_SOURCE) DataSource defaultDataSource,
                                                       @Qualifier(MYSQL_DATA_SOURCE) DataSource mysqlDataSource,
                                                       @Qualifier(ORACLE_DATA_SOURCE) DataSource oracleDataSource) {
        DynamicDataSourceRouter router = new DynamicDataSourceRouter();

        // The keys are the same constants used as bean names, so the two cannot drift apart.
        Map<Object, Object> targetDataSources = new HashMap<>(4);
        targetDataSources.put(DEFAULT_DATA_SOURCE, defaultDataSource);
        targetDataSources.put(MYSQL_DATA_SOURCE, mysqlDataSource);
        targetDataSources.put(ORACLE_DATA_SOURCE, oracleDataSource);

        router.setDefaultTargetDataSource(defaultDataSource);
        router.setTargetDataSources(targetDataSources);
        return router;
    }
}
  • 创建类 DataSourceContextHolder,使用 ThreadLocal 动态设置并保存当前线程所用数据源的 key
/**
 * Keeps the routing key of the data source selected by the current thread.
 *
 * <p>When nothing (or {@code null}) has been stored for the thread, callers get the key of the
 * default data source.
 */
public class DataSourceContextHolder {

    /** Key reported when the current thread has not selected a data source. */
    private static final String FALLBACK_KEY = "defaultDataSource";

    /** Per-thread slot holding the selected routing key. */
    private static final ThreadLocal<String> CURRENT_KEY = new ThreadLocal<>();

    /**
     * Stores the routing key for the current thread.
     *
     * @param type routing key; {@code null} makes {@link #getDataSource()} report the fallback
     */
    public static void setDataSource(String type) {
        CURRENT_KEY.set(type);
    }

    /**
     * Returns the routing key selected by the current thread.
     *
     * @return the stored key, or {@code "defaultDataSource"} when none is stored
     */
    public static String getDataSource() {
        String selected = CURRENT_KEY.get();
        if (selected != null) {
            return selected;
        }
        return FALLBACK_KEY;
    }

    /** Removes the current thread's entry so later reads report the fallback key. */
    public static void clear() {
        CURRENT_KEY.remove();
    }
}
  • 编写一个类继承AbstractRoutingDataSource,并重写 determineCurrentLookupKey 这个路由方法:
/**
 * Routing data source whose lookup key is whatever {@code DataSourceContextHolder} currently
 * holds for the calling thread.
 */
public class DynamicDataSourceRouter extends AbstractRoutingDataSource {

    /**
     * Supplies the key used to pick the target data source.
     *
     * @return the current thread's key from DataSourceContextHolder
     *         ({@code "defaultDataSource"} when none is set)
     */
    @Override
    protected Object determineCurrentLookupKey() {
        final String currentKey = DataSourceContextHolder.getDataSource();
        return currentKey;
    }
}
  • 创建 JPA 配置类(EntityManagerFactory 与事务管理器)
/**
 * JPA configuration that builds the EntityManagerFactory and JpaTransactionManager on top of the
 * {@code routingDataSource} bean.
 *
 * <p>As a result, repositories under {@code com.auto.code.**.repository} obtain connections from
 * whichever target the router currently selects.
 *
 * <p>NOTE(review): only one EntityManagerFactory is built here, with one set of Hibernate
 * properties. Hibernate therefore works with a single dialect for every routed target, MySQL
 * and Oracle alike. Dialect-specific SQL, such as pagination, may be wrong for one of them.
 * Confirm this against your queries, or build one EntityManagerFactory per database.
 */
@Configuration
@EnableConfigurationProperties(JpaProperties.class)
@EnableJpaRepositories(value = {"com.auto.code.**.repository"})
public class JpaEntityManager {

    @Autowired
    private JpaProperties jpaProperties;

    /** The router declared in DataSourceConfiguration; injected by bean name. */
    @Resource(name = "routingDataSource")
    private DataSource routingDataSource;

    /**
     * Creates the factory bean for the persistence unit.
     *
     * <p>The persistence unit uses the routing data source and scans entities in
     * {@code com.auto.code.**.entity}.
     *
     * @param builder Spring Boot's EntityManagerFactoryBuilder
     * @return the configured factory bean
     */
    @Bean(name = "entityManagerFactoryBean")
    public LocalContainerEntityManagerFactoryBean entityManagerFactoryBean(EntityManagerFactoryBuilder builder) {
        // JpaProperties.getProperties() holds only the spring.jpa.properties.* entries.
        // Keys under spring.jpa.hibernate.* (ddl-auto, the custom *-dialect keys, ...) are not
        // included, which is why the yml settings seemed to be "missing" here.
        // Copy the map so the put() below does not mutate the shared JpaProperties bean.
        Map<String, String> properties = new java.util.HashMap<>(jpaProperties.getProperties());
        // Map camelCase entity fields to snake_case column names.
        properties.put("hibernate.physical_naming_strategy",
                "org.springframework.boot.orm.jpa.hibernate.SpringPhysicalNamingStrategy");

        return builder
                .dataSource(routingDataSource) // key point: route every connection through the router
                .properties(properties)
                .packages("com.auto.code.**.entity")
                .persistenceUnit("myPersistenceUnit")
                .build();
    }

    /**
     * Exposes the factory bean's product as the primary EntityManagerFactory.
     *
     * <p>{@code @EnableJpaRepositories} looks up a bean named {@code entityManagerFactory} by
     * default. Inside this {@code @Configuration} class, the call to
     * {@link #entityManagerFactoryBean} is intercepted by Spring and returns the container-managed
     * factory bean; it does not build a second one.
     */
    @Primary
    @Bean(name = "entityManagerFactory")
    public EntityManagerFactory entityManagerFactory(EntityManagerFactoryBuilder builder) {
        return this.entityManagerFactoryBean(builder).getObject();
    }

    /**
     * Transaction manager bound to the primary EntityManagerFactory.
     *
     * <p>{@code @EnableJpaRepositories} looks up a bean named {@code transactionManager} by default.
     */
    @Primary
    @Bean(name = "transactionManager")
    public PlatformTransactionManager transactionManager(EntityManagerFactoryBuilder builder) {
        return new JpaTransactionManager(entityManagerFactory(builder));
    }
}
  • 通过切面,根据被调用的 service 方法名,决定本次请求读取哪个数据源
@Slf4j
@Aspect
@Component
public class DynamicDataSourceAspect {

    @Pointcut("execution(* com.auto.code..*.service..*.*(..))")
    private void aspect() {
    }

    @Around("aspect()")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        String method = joinPoint.getSignature().getName();
        if (method.contains("Default")) {
            DataSourceContextHolder.setDataSource("defaultDataSource");
            log.info("switch to default datasource...");
        } else if (method.contains("Mysql")) {
            DataSourceContextHolder.setDataSource("mysqlDataSource");
            log.info("switch to mysql datasource...");
        } else {
            DataSourceContextHolder.setDataSource("oracleDataSource");
            log.info("switch to oracle datasource...");
        }

        try {
            return joinPoint.proceed();
        } finally {
            log.info("清除 datasource router...");
            DataSourceContextHolder.clear();
        }
    }
}

至此,核心代码实现完毕。