Java解决单机环境下多数据源的事务问题

Spring Boot 单机环境下的 @Transactional 可以保证单个数据源的事务,但在多数据源的情况下就无法直接使用了。这里简单实现一下多数据源情况下如何保证事务。

一,事务实现方案

利用 ThreadLocal 将事务方法内用到的 connection 缓存起来,当业务执行完毕后,再统一 commit 或者 rollback。注意:这只是在单机内按顺序逐个提交,并不是两阶段提交(2PC);如果前一个数据源已经提交成功而后一个数据源提交失败,已提交的数据无法再回滚。

二,自定义开启事务注解

/**
 * Marks a method whose JDBC work across several data sources should be committed together on
 * success and rolled back together on failure.
 *
 * <p>Methods carrying this annotation are matched by the {@code MultiDSTransactionConfig} aspect
 * pointcut. Retained at runtime so the pointcut can see it on the invoked method.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MultiDSTransaction {}

三,yml配置多数据源

spring:
  datasource:
    # Pool type hint. NOTE(review): DataSourceConfig builds both data sources with
    # DruidDataSourceBuilder, so this key does not appear to select their implementation — confirm.
    type: com.alibaba.druid.pool.DruidDataSource
    # Bound onto the dataSource1 bean via @ConfigurationProperties(prefix = "spring.datasource.datasource1").
    datasource1:
      url: jdbc:mysql://localhost:3306/datasource1?serverTimezone=UTC&useUnicode=true
      username: root
      # NOTE(review): plaintext credential; consider an environment variable or secret store.
      password: 1234
      driver-class-name: com.mysql.cj.jdbc.Driver
    # Bound onto the dataSource2 bean via @ConfigurationProperties(prefix = "spring.datasource.datasource2").
    datasource2:
      # NOTE(review): datasource1 has no such key; presumably redundant because DataSourceConfig always
      # builds a Druid pool — confirm, then remove it here or mirror it under datasource1 for consistency.
      type: com.alibaba.druid.pool.DruidDataSource
      url: jdbc:mysql://localhost:3306/datasource2?serverTimezone=UTC&useUnicode=true
      username: root
      # NOTE(review): plaintext credential; consider an environment variable or secret store.
      password: 1234
      driver-class-name: com.mysql.cj.jdbc.Driver

四,多数据源配置类

/**
 * 注入两个数据源
 */
@Configuration
public class DataSourceConfig {

    @Bean
    @ConfigurationProperties(prefix = "spring.datasource.datasource1")
    public DataSource dataSource1(){
        return DruidDataSourceBuilder.create().build();
    }

    @Bean
    @ConfigurationProperties(prefix = "spring.datasource.datasource2")
    public DataSource dataSource2(){
        return DruidDataSourceBuilder.create().build();
    }

五,重写Connection连接

package com.example.demo.config;

import com.example.demo.config.TransactionContext;

import java.sql.*;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.Executor;

public class CustomConnection implements Connection {
    // 真实的连接
    private Connection connection;

    public CustomConnection(Connection connection) {
        this.connection = connection;
    }
    @Override
    public void commit() throws SQLException {
        // 如果没开启多数据源事务,则走 commit
        if (!TransactionContext.isOpenTran()) {
            connection.commit();
        }
    }

    @Override
    public void rollback() throws SQLException {
        connection.rollback();
    }

    public void commitMultiDbTran() throws SQLException {
        // 如果开启多数据源,则走 这里的 commit
        connection.commit();
    }
    @Override
    public void close() throws SQLException {
        // mybatis 执行完业务后,会触发 close() 操作,如果 connection 被提前 close 了,业务就会出错
        if (!TransactionContext.isOpenTran()) {
            connection.close();
        }
    }
    public void closeMultiDbTran() throws SQLException {
        // 如果开启多数据源事务,则走 这里的 close
        connection.close();
    }

    @Override
    public Statement createStatement() throws SQLException {
        return connection.createStatement();
    }

    @Override
    public PreparedStatement prepareStatement(String sql) throws SQLException {
        return connection.prepareStatement(sql);
    }

    @Override
    public CallableStatement prepareCall(String sql) throws SQLException {
        return connection.prepareCall(sql);
    }

    @Override
    public String nativeSQL(String sql) throws SQLException {
        return connection.nativeSQL(sql);
    }

    @Override
    public void setAutoCommit(boolean autoCommit) throws SQLException {
        connection.setAutoCommit(autoCommit);
    }

    @Override
    public boolean getAutoCommit() throws SQLException {
        return connection.getAutoCommit();
    }

    @Override
    public boolean isClosed() throws SQLException {
        return connection.isClosed();
    }

    @Override
    public DatabaseMetaData getMetaData() throws SQLException {
        return connection.getMetaData();
    }

    @Override
    public void setReadOnly(boolean readOnly) throws SQLException {
        connection.setReadOnly(readOnly);
    }

    @Override
    public boolean isReadOnly() throws SQLException {
        return connection.isReadOnly();
    }

    @Override
    public void setCatalog(String catalog) throws SQLException {
        connection.setCatalog(catalog);
    }

    @Override
    public String getCatalog() throws SQLException {
        return connection.getCatalog();
    }

    @Override
    public void setTransactionIsolation(int level) throws SQLException {
        connection.setTransactionIsolation(level);
    }

    @Override
    public int getTransactionIsolation() throws SQLException {
        return connection.getTransactionIsolation();
    }

    @Override
    public SQLWarning getWarnings() throws SQLException {
        return connection.getWarnings();
    }

    @Override
    public void clearWarnings() throws SQLException {
        connection.clearWarnings();
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException {
        return connection.createStatement(resultSetType,resultSetConcurrency);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        return connection.prepareStatement(sql,resultSetType,resultSetConcurrency);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        return connection.prepareCall(sql,resultSetType,resultSetConcurrency);
    }

    @Override
    public Map<String, Class<?>> getTypeMap() throws SQLException {
        return connection.getTypeMap();
    }

    @Override
    public void setTypeMap(Map<String, Class<?>> map) throws SQLException {
        connection.setTypeMap(map);
    }

    @Override
    public void setHoldability(int holdability) throws SQLException {
        connection.setHoldability(holdability);
    }

    @Override
    public int getHoldability() throws SQLException {
        return connection.getHoldability();
    }

    @Override
    public Savepoint setSavepoint() throws SQLException {
        return connection.setSavepoint();
    }

    @Override
    public Savepoint setSavepoint(String name) throws SQLException {
        return connection.setSavepoint(name);
    }

    @Override
    public void rollback(Savepoint savepoint) throws SQLException {
        connection.rollback(savepoint);
    }

    @Override
    public void releaseSavepoint(Savepoint savepoint) throws SQLException {
        connection.releaseSavepoint(savepoint);
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        return connection.createStatement(resultSetType,resultSetConcurrency,resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        return connection.prepareStatement(sql,resultSetType,resultSetConcurrency,resultSetHoldability);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        return connection.prepareCall(sql,resultSetType,resultSetConcurrency,resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException {
        return connection.prepareStatement(sql,autoGeneratedKeys);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException {
        return connection.prepareStatement(sql,columnIndexes);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException {
        return connection.prepareStatement(sql,columnNames);
    }

    @Override
    public Clob createClob() throws SQLException {
        return connection.createClob();
    }

    @Override
    public Blob createBlob() throws SQLException {
        return connection.createBlob();
    }

    @Override
    public NClob createNClob() throws SQLException {
        return connection.createNClob();
    }

    @Override
    public SQLXML createSQLXML() throws SQLException {
        return connection.createSQLXML();
    }

    @Override
    public boolean isValid(int timeout) throws SQLException {
        return connection.isValid(timeout);
    }

    @Override
    public void setClientInfo(String name, String value) throws SQLClientInfoException {
        connection.setClientInfo(name,value);
    }

    @Override
    public void setClientInfo(Properties properties) throws SQLClientInfoException {
        connection.setClientInfo(properties);
    }

    @Override
    public String getClientInfo(String name) throws SQLException {
        return connection.getClientInfo(name);
    }

    @Override
    public Properties getClientInfo() throws SQLException {
        return connection.getClientInfo();
    }

    @Override
    public Array createArrayOf(String typeName, Object[] elements) throws SQLException {
        return connection.createArrayOf(typeName,elements);
    }

    @Override
    public Struct createStruct(String typeName, Object[] attributes) throws SQLException {
        return connection.createStruct(typeName,attributes);
    }

    @Override
    public void setSchema(String schema) throws SQLException {
        connection.setSchema(schema);
    }

    @Override
    public String getSchema() throws SQLException {
        return connection.getSchema();
    }

    @Override
    public void abort(Executor executor) throws SQLException {
        connection.abort(executor);
    }

    @Override
    public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException {
        connection.setNetworkTimeout(executor,milliseconds);
    }

    @Override
    public int getNetworkTimeout() throws SQLException {
        return connection.getNetworkTimeout();
    }

    @Override
    public <T> T unwrap(Class<T> iface) throws SQLException {
        return connection.unwrap(iface);
    }

    @Override
    public boolean isWrapperFor(Class<?> iface) throws SQLException {
        return connection.isWrapperFor(iface);
    }
}

六,事务管理Context

package com.example.demo.config;

/**
 * Per-thread switch recording whether a multi-data-source transaction is active.
 *
 * <p>Every thread starts with the switch off. {@link #openTran()} turns it on for the calling thread
 * only. {@link #closeTran()} discards the thread's entry, so the default (off) applies again and
 * pooled threads carry no stale state into later requests.
 */
public class TransactionContext {
    // withInitial gives EVERY thread a default of false. A static initializer would only set it on the
    // thread that happened to load this class, leaving null on all others (and an NPE on unboxing).
    private static final ThreadLocal<Boolean> TRAN_SWITCH_CONTEXT =
            ThreadLocal.withInitial(() -> Boolean.FALSE);

    /** Marks a multi-data-source transaction as active on the current thread. */
    public static void openTran() {
        TRAN_SWITCH_CONTEXT.set(Boolean.TRUE);
    }

    /** Marks the current thread as having no active multi-data-source transaction. */
    public static void closeTran() {
        // remove() instead of set(false): drops the entry entirely; get() falls back to the initial false.
        TRAN_SWITCH_CONTEXT.remove();
    }

    /**
     * Reports whether a multi-data-source transaction is active on the current thread.
     *
     * @return {@code TRUE} if {@link #openTran()} was called without a later {@link #closeTran()};
     *     never {@code null}
     */
    public static Boolean isOpenTran() {
        return TRAN_SWITCH_CONTEXT.get();
    }
}

七,自定义数据源

在自定义数据源中注入上边那两个数据源,维护多数据源执行事务期间用到的连接列表,并在自定义数据源中添加事务相关逻辑,即在获取连接的地方将 Connection 缓存到 ThreadLocal 中。

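使用 @Primary 注解后,该 bean 会成为 DataSource 类型的首选注入 bean:按类型注入 DataSource 时只会注入这个自定义数据源,而不会注入 DataSourceConfig 配置类中定义的两个 bean。因此 MyBatis(mybatis-spring 内部通过 DataSourceUtils.getConnection 获取连接)拿到的是这个自定义数据源,执行的是它的 getConnection 方法。(DynamicDataSource 自己注入 dataSource1/dataSource2 时,Spring 会先把它自身从候选中排除,再按字段名匹配到对应的数据源。)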

package com.example.demo.config;

import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Component;

import javax.sql.DataSource;
import java.io.PrintWriter;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;

/**
 * Routing {@link DataSource} over the two physical data sources. It also records the connections
 * opened while a multi-data-source transaction is active.
 *
 * <p>The target is chosen on every call from the thread-local {@link #name} key: "one" selects
 * {@code dataSource1} and "two" selects {@code dataSource2}. Each connection handed out is wrapped in
 * a {@link CustomConnection}. If {@link TransactionContext#isOpenTran()} is true, auto-commit is
 * switched off and the wrapper is appended to {@link #MULTI_TRAN_CONNECTION}, so it can later be
 * committed or rolled back together with the others.
 */
@Component
@Primary // Preferred DataSource when injecting by type, so consumers get this router rather than dataSource1/dataSource2 directly.
public class DynamicDataSource implements DataSource, InitializingBean {

    /** Routing key that selects {@code dataSource1}. */
    private static final String DATASOURCE_ONE = "one";
    /** Routing key that selects {@code dataSource2}. */
    private static final String DATASOURCE_TWO = "two";

    /**
     * Connections handed out on the current thread while a multi-data-source transaction is active.
     * Starts as an empty list on every thread, so readers never see {@code null}.
     */
    public static final ThreadLocal<List<CustomConnection>> MULTI_TRAN_CONNECTION =
            ThreadLocal.withInitial(ArrayList::new);

    /**
     * Key of the data source the current thread should use ("one" or "two"). It defaults to "one" on
     * every thread; setting it only in afterPropertiesSet would leave request threads with null.
     */
    public static final ThreadLocal<String> name = ThreadLocal.withInitial(() -> DATASOURCE_ONE);

    // The two physical data sources. Resolved by field name: Spring excludes this bean itself from the
    // candidates, and neither remaining candidate is @Primary.
    @Autowired
    private DataSource dataSource1;
    @Autowired
    private DataSource dataSource2;

    @Override
    public Connection getConnection() throws SQLException {
        return enlist(resolveTarget().getConnection());
    }

    /** Same routing and enlistment as {@link #getConnection()}, with explicit credentials. */
    @Override
    public Connection getConnection(String username, String password) throws SQLException {
        return enlist(resolveTarget().getConnection(username, password));
    }

    /** Returns the physical data source selected by the current thread's {@link #name} key. */
    private DataSource resolveTarget() throws SQLException {
        String key = name.get();
        if (DATASOURCE_ONE.equals(key)) {
            return dataSource1;
        }
        if (DATASOURCE_TWO.equals(key)) {
            return dataSource2;
        }
        // Fail loudly instead of wrapping a null connection that would blow up on first use.
        throw new SQLException("Unknown data source key: " + key + " (expected \"one\" or \"two\")");
    }

    /**
     * Wraps {@code physical} in a {@link CustomConnection}. When a multi-data-source transaction is
     * active, also disables auto-commit and records the wrapper for the current thread.
     */
    private CustomConnection enlist(Connection physical) throws SQLException {
        CustomConnection wrapped = new CustomConnection(physical);
        if (Boolean.TRUE.equals(TransactionContext.isOpenTran())) {
            try {
                wrapped.setAutoCommit(false);
            } catch (SQLException e) {
                // Not yet recorded, so nothing else will ever close it: release it before propagating.
                try {
                    physical.close();
                } catch (SQLException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
                throw e;
            }
            MULTI_TRAN_CONNECTION.get().add(wrapped);
        }
        return wrapped;
    }

    /** Per the {@link java.sql.Wrapper} contract: returns this object if it implements {@code iface}. */
    @Override
    public <T> T unwrap(Class<T> iface) throws SQLException {
        if (iface.isInstance(this)) {
            return iface.cast(this);
        }
        throw new SQLException("DynamicDataSource is not a wrapper for " + iface.getName());
    }

    @Override
    public boolean isWrapperFor(Class<?> iface) throws SQLException {
        return iface.isInstance(this);
    }

    // Log writer and login timeout are not supported by this router: setters are no-ops.
    @Override
    public PrintWriter getLogWriter() throws SQLException {
        return null;
    }

    @Override
    public void setLogWriter(PrintWriter out) throws SQLException {
    }

    @Override
    public void setLoginTimeout(int seconds) throws SQLException {
    }

    @Override
    public int getLoginTimeout() throws SQLException {
        return 0;
    }

    /** This router does not use java.util.logging, so per the DataSource contract this throws. */
    @Override
    public Logger getParentLogger() throws SQLFeatureNotSupportedException {
        throw new SQLFeatureNotSupportedException("DynamicDataSource does not use java.util.logging");
    }

    @Override
    public void afterPropertiesSet() throws Exception {
        // The per-thread default already comes from ThreadLocal.withInitial. This only re-asserts
        // "one" on the thread that initializes the bean and is kept for compatibility.
        name.set(DATASOURCE_ONE);
    }
}

八,写切面

利用 AOP 进行方法拦截,对使用了多数据源事务注解的方法执行事务逻辑。

package com.example.demo.aspect;

import com.example.demo.config.CustomConnection;
import com.example.demo.config.DynamicDataSource;
import com.example.demo.config.TransactionContext;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.context.annotation.Configuration;

/**
 * Aspect that wraps every {@code @MultiDSTransaction} method in a transaction spanning the data
 * sources it touches.
 *
 * <p>It switches on {@link TransactionContext} for the current thread and then runs the method.
 * During the method, {@link DynamicDataSource} records each connection it hands out. On success every
 * recorded connection is committed; on failure they are all rolled back. Finally they are all closed
 * and the thread-local state is reset.
 *
 * <p>NOTE: commits are issued one connection after another, not via two-phase commit. If a later
 * commit fails after an earlier one succeeded, the earlier data source stays committed.
 */
@Aspect
@Configuration
public class MultiDSTransactionConfig {

    /** Matches methods annotated with {@code @MultiDSTransaction}. */
    @Pointcut("@annotation(com.example.demo.annotation.MultiDSTransaction)")
    public void transactPoint() {}

    /**
     * Runs the annotated method inside a multi-data-source transaction.
     *
     * @return whatever the annotated method returns
     * @throws Throwable the method's own exception (with rollback/close failures suppressed onto it), a
     *     commit failure, or, if everything else succeeded, the first close failure
     */
    @Around("transactPoint()")
    public Object multiTranAop(ProceedingJoinPoint joinPoint) throws Throwable {
        if (Boolean.TRUE.equals(TransactionContext.isOpenTran())) {
            // Already inside an outer @MultiDSTransaction on this thread: join it. Otherwise this inner
            // call would commit and close the outer transaction's connections when it returns.
            return joinPoint.proceed();
        }
        TransactionContext.openTran();
        Throwable failure = null;
        try {
            Object result = joinPoint.proceed();
            // Commit everything first; connections are closed only afterwards in finally, so a
            // commit failure can still roll back the not-yet-committed ones.
            for (CustomConnection connection : enlistedConnections()) {
                connection.commitMultiDbTran();
            }
            return result;
        } catch (Throwable t) {
            failure = t;
            rollbackAll(t);
            throw t;
        } finally {
            try {
                closeAll(failure);
            } finally {
                // remove() rather than clear(): drop the per-thread entries so pooled threads keep nothing.
                DynamicDataSource.MULTI_TRAN_CONNECTION.remove();
                TransactionContext.closeTran();
            }
        }
    }

    /**
     * Rolls back every recorded connection, continuing past individual failures. Those failures are
     * attached to {@code primary} as suppressed exceptions.
     */
    private static void rollbackAll(Throwable primary) {
        for (CustomConnection connection : enlistedConnections()) {
            try {
                connection.rollback();
            } catch (java.sql.SQLException rollbackFailure) {
                primary.addSuppressed(rollbackFailure);
            }
        }
    }

    /**
     * Closes every recorded connection, attempting all of them even if some fail.
     *
     * <p>When the method already failed ({@code primary != null}), close failures are attached to it.
     * Otherwise the first close failure is thrown, carrying later ones as suppressed, once all
     * connections have been tried.
     */
    private static void closeAll(Throwable primary) throws java.sql.SQLException {
        java.sql.SQLException firstCloseFailure = null;
        for (CustomConnection connection : enlistedConnections()) {
            try {
                connection.closeMultiDbTran();
            } catch (java.sql.SQLException closeFailure) {
                if (primary != null) {
                    primary.addSuppressed(closeFailure);
                } else if (firstCloseFailure == null) {
                    firstCloseFailure = closeFailure;
                } else {
                    firstCloseFailure.addSuppressed(closeFailure);
                }
            }
        }
        if (firstCloseFailure != null) {
            throw firstCloseFailure;
        }
    }

    /** Connections recorded on the current thread; an empty iterable when none were opened. */
    private static Iterable<CustomConnection> enlistedConnections() {
        Iterable<CustomConnection> connections = DynamicDataSource.MULTI_TRAN_CONNECTION.get();
        return connections != null ? connections : java.util.Collections.<CustomConnection>emptyList();
    }
}

九,测试

测试代码如下:当出现报错时,事务同时回滚,数据没有插入成功;当未出现报错时,数据都插入成功。

/**
 * MyBatis mapper with one demo insert per database.
 *
 * <p>The mapper does not pick a database itself. Which physical data source executes a statement is
 * decided at call time by the routing key {@code DynamicDataSource.name}.
 */
@Mapper
public interface TestMapper {

    /**
     * Inserts a demo row into {@code test1}; meant to run while datasource1 is selected.
     *
     * @return the value MyBatis reports for the insert (presumably the affected-row count)
     */
    @Insert("insert into test1(name) values ('呵呵');")
    int test1();

    /**
     * Inserts a demo row into {@code test2}; meant to run while datasource2 is selected.
     *
     * @return the value MyBatis reports for the insert (presumably the affected-row count)
     */
    @Insert("insert into test2(name) values ('哈哈');")
    int test2();
}
/**
 * Demo controller exercising the multi-data-source transaction.
 */
@RestController
public class TestController {

    @Autowired
    private TestMapper testMapper;

    /**
     * Inserts one row through datasource1 and one through datasource2 inside a single
     * {@code @MultiDSTransaction}. The deliberate division by zero makes the method throw after the
     * first insert, so the surrounding aspect rolls back and nothing is persisted. Remove that line to
     * see both inserts commit.
     */
    @RequestMapping("test")
    @MultiDSTransaction
    public void test() {
        // Request threads are usually pooled: remember the routing key and restore it afterwards so
        // a later request on this thread does not silently keep using datasource2.
        String previousKey = DynamicDataSource.name.get();
        try {
            // Route to datasource1
            DynamicDataSource.name.set("one");
            testMapper.test1();

            // Simulated failure
            int k = 1 / 0;

            // Route to datasource2
            DynamicDataSource.name.set("two");
            testMapper.test2();

            System.out.println("over");
        } finally {
            if (previousKey == null) {
                DynamicDataSource.name.remove();
            } else {
                DynamicDataSource.name.set(previousKey);
            }
        }
    }
}
posted @ 2022-11-07 19:51  你樊不樊  阅读(1011)  评论(0)  编辑  收藏  举报