手写分布式事务
这个例子仿照 Seata 的 AT 模式(注:下面的实现是让各服务的本地事务先挂起、等待事务管理器统一通知后再提交或回滚,思路上更接近两阶段提交 / LCN 模式,而不是 AT 模式基于 undo log 的补偿回滚)
分布式事务产生:
其中 localsql 和 others 方法都是对当前服务的数据库进行操作,而 remoteMethod 接口调用的是远程服务(操作的是远程服务的数据库),单纯使用 Spring 的 @Transactional 注解无法回滚其他服务中的操作
思路分析:
首先要解决这个问题
1.要知道localsql remoteMethod others 这三个方法是同一个事务
2.执行 remoteMethod 最终是 success 还是 error,需要通知给这个事务,最后据此进行提交或者回滚操作
3.事务的控制
解决方案:
1.需要单独创建一个 transactionManager 来协调事务之间的关系,即事务管理中心:各服务将事务的执行结果通知给事务中心,由事务中心反馈是提交还是回滚,可以使用 Netty 做通信(解决事务可见性问题)
2.可以在TM上使用map去标记一个全局事务 Map<全局事务ID, List[分支事务1,分支事务2....]>,方便事务管理 (解决事务分组问题)
3.自定义注解,使用 AOP 设置一个环绕切面,围绕目标方法执行 4 步(1.建立连接 2.开启事务 3.执行方法 4.提交/回滚)(解决事务控制权问题)
实现:
<?xml version="1.0" encoding="UTF-8"?> <!-- server1 demo service: Spring Boot web, MyBatis and MySQL Connector/J 8.0.15, plus the lbglobaltransaction library (presumably supplying the @Lbtransactional aspects and the Netty client to the transaction manager). NOTE(review): this module declares no groupId of its own, so it inherits org.springframework.boot from the parent; the explicit spring-boot-starter-web version duplicates the parent's managed version. --> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <parent> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-parent</artifactId> <version>2.1.0.RELEASE</version> <relativePath/> <!-- lookup parent from repository --> </parent> <modelVersion>4.0.0</modelVersion> <artifactId>server1</artifactId> <dependencies> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-web</artifactId> <version>2.1.0.RELEASE</version> </dependency> <dependency> <groupId>org.mybatis.spring.boot</groupId> <artifactId>mybatis-spring-boot-starter</artifactId> <version>1.3.2</version> </dependency> <dependency> <groupId>mysql</groupId> <artifactId>mysql-connector-java</artifactId> <version>8.0.15</version> </dependency> <dependency> <groupId>com.luban</groupId> <artifactId>lbglobaltransaction</artifactId> <version>1.0-SNAPSHOT</version> </dependency> </dependencies> <build> <plugins> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-compiler-plugin</artifactId> <configuration> <source>8</source> <target>8</target> </configuration> </plugin> </plugins> </build> </project>
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

/**
 * HTTP entry point of the server1 demo.
 *
 * <p>A request to {@code /server1/test} runs {@link DemoService#test()}, which opens the
 * distributed transaction group.
 */
@RestController
@RequestMapping("server1")
public class DemoController {

    @Autowired
    private DemoService demoService;

    /** Delegates straight to the demo service; the response body is empty. */
    @RequestMapping("test")
    public void test() {
        demoService.test();
    }
}
import com.hz.lbtransaction.annotation.Lbtransactional;
import com.hz.lbtransaction.util.HttpClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

/**
 * Demo business service of server1. Its {@link #test()} method opens a new distributed
 * transaction group and takes part in it as a local branch.
 */
@Service
public class DemoService {

    @Autowired
    private DemoDao demoDao;

    /**
     * Writes a local row, calls server2's test endpoint (which joins the same group via the
     * propagated headers), and then fails on purpose.
     *
     * <p>The failure makes this branch report "rollback" to the transaction manager, which is
     * what the demo uses to show that server2's work is rolled back too.
     */
    @Transactional
    @Lbtransactional(isStart = true)
    public void test() {
        demoDao.insert("server1");
        HttpClient.get("http://localhost:8082/server2/test");
        // Deliberate ArithmeticException to drive the rollback path of the demo.
        int forcedFailure = 100 / 0;
    }
}
# MySQL Connector/J 8.x (see pom: mysql-connector-java 8.0.15) ships the driver as
# com.mysql.cj.jdbc.Driver; the old com.mysql.jdbc.Driver name is deprecated there.
spring.datasource.driverClassName = com.mysql.cj.jdbc.Driver
spring.datasource.url = jdbc:mysql://localhost:3306/hs?serverTimezone=UTC&useUnicode=true&characterEncoding=utf-8&AllowPublicKeyRetrieval=True
spring.datasource.username = root
spring.datasource.password = root
mybatis.type-aliases-package=com.luban.server
mybatis.mapper-locations=classpath:*.xml
mybatis.configuration.log-impl=org.apache.ibatis.logging.stdout.StdOutImpl
server.port=8081
Server2项目和Server1相似,以下不同需要注意
import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.handler.codec.string.StringDecoder; import io.netty.handler.codec.string.StringEncoder; public class NettyServer { public void start(String hostName, int port) { try { final ServerBootstrap bootstrap = new ServerBootstrap(); NioEventLoopGroup eventLoopGroup = new NioEventLoopGroup(); bootstrap.group(eventLoopGroup) .channel(NioServerSocketChannel.class) .childHandler(new ChannelInitializer<SocketChannel>() { protected void initChannel(SocketChannel socketChannel) throws Exception { ChannelPipeline pipeline = socketChannel.pipeline(); pipeline.addLast("decoder", new StringDecoder()); pipeline.addLast("encoder", new StringEncoder()); pipeline.addLast("handler", new NettyServerHandler()); } }); bootstrap.bind(hostName, port).sync(); } catch (InterruptedException e) { e.printStackTrace(); } } public void close() { } }
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.group.ChannelGroup; import io.netty.channel.group.DefaultChannelGroup; import io.netty.util.concurrent.GlobalEventExecutor; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; /** * 作为事务管理者,它需要: * 1. 创建并保存事务组 * 2. 保存各个子事务在对应的事务组内 * 3. 统计并判断事务组内的各个子事务状态,以算出当前事务组的状态(提交or回滚) * 4. 通知各个子事务提交或回滚 */ public class NettyServerHandler extends ChannelInboundHandlerAdapter { private static ChannelGroup channelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE); // 事务组中的事务状态列表 private static Map<String, List<String>> transactionTypeMap = new HashMap<String, List<String>>(); // 事务组是否已经接收到结束的标记 private static Map<String, Boolean> isEndMap = new HashMap<String, Boolean>(); // 事务组中应该有的事务个数 private static Map<String, Integer> transactionCountMap = new HashMap<String, Integer>(); @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { Channel channel = ctx.channel(); channelGroup.add(ctx.channel()); } /** * * {groupId:List<子事务>} * 1. 接收创建事务组事件 * 2. 接收子事务的注册事件 * 3. 判断事务组的状态,如果该事务组中有一个事务需要回滚,那么整个事务组就需要回滚,反之,则整个事务组提交 * 4. 
通知所有客户端进行提交或回滚 * @param ctx * @param msg * @throws Exception */ @Override public synchronized void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { System.out.println("接受数据:" + msg.toString()); JSONObject jsonObject = JSON.parseObject((String) msg); String command = jsonObject.getString("command"); // create-创建事务组,add-添加事务 String groupId = jsonObject.getString("groupId"); // 事务组id String transactionType = jsonObject.getString("transactionType"); // 子事务类型,commit-待提交,rollback-待回滚 Integer transactionCount = jsonObject.getInteger("transactionCount"); // 事务数量 Boolean isEnd = jsonObject.getBoolean("isEnd"); // 是否是结束事务 if ("create".equals(command)) { // 创建事务组 transactionTypeMap.put(groupId, new ArrayList<String>()); } else if ("add".equals(command)) { // 加入事务组 transactionTypeMap.get(groupId).add(transactionType); if (isEnd) { isEndMap.put(groupId, true); transactionCountMap.put(groupId, transactionCount); } JSONObject result = new JSONObject(); result.put("groupId", groupId); // 如果已经接收到结束事务的标记,比较事务是否已经全部到达,如果已经全部到达则看是否需要回滚 if (isEndMap.get(groupId) && transactionCountMap.get(groupId).equals(transactionTypeMap.get(groupId).size())) { if (transactionTypeMap.get(groupId).contains("rollback")){ result.put("command", "rollback"); sendResult(result); } else { result.put("command", "commit"); sendResult(result); } } } } private void sendResult(JSONObject result) { for (Channel channel : channelGroup) { System.out.println("发送数据:" + result.toJSONString()); channel.writeAndFlush(result.toJSONString()); } } }
public class TxManagerMain { public static void main(String[] args) { NettyServer nettyServer = new NettyServer(); nettyServer.start("localhost", 8080); System.out.println("netty 启动成功"); } }
<?xml version="1.0" encoding="UTF-8"?> <!-- lbglobaltransaction: shared client library for the distributed-transaction demo. It depends on Spring web and AOP (aspects, interceptors), Netty (connection to the transaction manager), Apache HttpClient (header-propagating remote calls) and fastjson (message encoding). NOTE(review): fastjson releases before 1.2.83 have published autoType deserialization vulnerabilities (e.g. CVE-2022-25845); consider upgrading from 1.2.51. --> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <parent> <artifactId>lbtransaction</artifactId> <groupId>com.luban</groupId> <version>1.0-SNAPSHOT</version> </parent> <modelVersion>4.0.0</modelVersion> <artifactId>lbglobaltransaction</artifactId> <dependencies> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-web</artifactId> <version>2.1.0.RELEASE</version> </dependency> <dependency> <groupId>io.netty</groupId> <artifactId>netty-all</artifactId> <version>4.1.16.Final</version> </dependency> <dependency> <groupId>org.apache.httpcomponents</groupId> <artifactId>httpclient</artifactId> <version>4.5.4</version> </dependency> <dependency> <groupId>com.alibaba</groupId> <artifactId>fastjson</artifactId> <version>1.2.51</version> </dependency> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-aop</artifactId> <version>2.1.0.RELEASE</version> </dependency> </dependencies> <build> <plugins> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-compiler-plugin</artifactId> <configuration> <source>8</source> <target>8</target> </configuration> </plugin> </plugins> </build> </project>
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a method as a branch of a distributed transaction.
 *
 * <p>The annotation is retained at runtime so the AOP aspect can read it.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Lbtransactional {

    /** True on the method that opens a new transaction group; otherwise the current group is joined. */
    boolean isStart() default false;

    /** True on the branch whose registration carries the end-of-group marker. */
    boolean isEnd() default false;
}
import com.hz.lbtransaction.connection.LbConnection;
import com.hz.lbtransaction.transactional.LbTransactionManager;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.springframework.stereotype.Component;
import java.sql.Connection;

/**
 * Wraps JDBC connections handed out while a branch transaction is active on the current thread.
 *
 * <p>The pointcut names the {@code javax.sql.DataSource} interface, so it matches
 * {@code getConnection(..)} on any DataSource implementation that Spring AOP proxies
 * (presumably the application's Spring-managed DataSource beans).
 */
@Aspect
@Component
public class LbDataSourceAspect {

    /**
     * Obtains the physical connection and, when the current thread has a branch transaction,
     * returns it wrapped in an {@link LbConnection} so that commit/rollback is deferred until the
     * transaction manager decides.
     *
     * <p>NOTE(review): the thread-local branch transaction is never cleared after the annotated
     * method finishes. Later connections on a reused thread would still be wrapped; confirm this
     * is intended.
     *
     * @param point the intercepted getConnection call
     * @return the physical connection, or an LbConnection around it
     * @throws Throwable anything the underlying getConnection throws
     */
    @Around("execution(* javax.sql.DataSource.getConnection(..))")
    public Connection around(ProceedingJoinPoint point) throws Throwable {
        Connection physical = (Connection) point.proceed();
        if (LbTransactionManager.getCurrent() == null) {
            return physical;
        }
        return new LbConnection(physical, LbTransactionManager.getCurrent());
    }
}
import com.hz.lbtransaction.annotation.Lbtransactional; import com.hz.lbtransaction.transactional.LbTransaction; import com.hz.lbtransaction.transactional.LbTransactionManager; import com.hz.lbtransaction.transactional.TransactionType; import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.annotation.Around; import org.aspectj.lang.annotation.Aspect; import org.aspectj.lang.reflect.MethodSignature; import org.springframework.core.Ordered; import org.springframework.stereotype.Component; import java.lang.reflect.Method; @Aspect @Component public class LbTransactionAspect implements Ordered { @Around("@annotation(com.hz.lbtransaction.annotation.Lbtransactional)") public void invoke(ProceedingJoinPoint point) { // 打印出这个注解所对应的方法 MethodSignature signature = (MethodSignature) point.getSignature(); Method method = signature.getMethod(); Lbtransactional lbAnnotation = method.getAnnotation(Lbtransactional.class); String groupId = ""; if (lbAnnotation.isStart()) { groupId = LbTransactionManager.createLbTransactionGroup(); } else { groupId = LbTransactionManager.getCurrentGroupId(); } LbTransaction lbTransaction = LbTransactionManager.createLbTransaction(groupId); try { // spring会开启mysql事务 point.proceed(); LbTransactionManager.addLbTransaction(lbTransaction, lbAnnotation.isEnd(), TransactionType.commit); } catch (Exception e) { LbTransactionManager.addLbTransaction(lbTransaction, lbAnnotation.isEnd(), TransactionType.rollback); e.printStackTrace(); } catch (Throwable throwable) { LbTransactionManager.addLbTransaction(lbTransaction, lbAnnotation.isEnd(), TransactionType.rollback); throwable.printStackTrace(); } } @Override public int getOrder() { return 10000; } }
import com.hz.lbtransaction.transactional.TransactionType;
import com.hz.lbtransaction.transactional.LbTransaction;
import java.sql.*;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.Executor;

/**
 * JDBC connection wrapper that defers the local commit/rollback until the transaction manager
 * (TM) has decided the outcome of the branch's group.
 *
 * <p>{@link #commit()} and {@link #rollback()} return immediately. A background thread waits for
 * the decision on the branch's {@code Task}, applies it to the physical connection, and then
 * closes that connection. {@link #close()} is therefore a no-op. All other methods delegate
 * unchanged.
 */
public class LbConnection implements Connection {

    private final Connection connection;
    private final LbTransaction lbTransaction;

    public LbConnection(Connection connection, LbTransaction lbTransaction) {
        this.connection = connection;
        this.lbTransaction = lbTransaction;
    }

    /** Defers the commit: the TM's decision decides whether the work is committed or rolled back. */
    @Override
    public void commit() throws SQLException {
        finishWhenDecided(false);
    }

    /** Defers the rollback until the TM's decision arrives; the work is rolled back regardless. */
    @Override
    public void rollback() throws SQLException {
        finishWhenDecided(true);
    }

    /**
     * Starts a thread that waits for the TM's decision and then completes the physical transaction.
     * The connection is closed in every case, even if commit/rollback fails.
     *
     * @param rollbackOnly true when the local transaction already failed and must be rolled back
     */
    private void finishWhenDecided(final boolean rollbackOnly) {
        new Thread(() -> {
            try {
                lbTransaction.getTask().waitTask();
                // Commit only on an explicit "commit" decision; anything else is rolled back.
                if (!rollbackOnly && lbTransaction.getTransactionType() == TransactionType.commit) {
                    connection.commit();
                } else {
                    connection.rollback();
                }
            } catch (SQLException e) {
                e.printStackTrace();
            } finally {
                closePhysicalQuietly();
            }
        }).start();
    }

    /** Closes the wrapped connection, logging rather than propagating any failure. */
    private void closePhysicalQuietly() {
        try {
            connection.close();
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

    /** No-op: the physical connection is closed by the deferred commit/rollback thread. */
    @Override
    public void close() throws SQLException {
    }

    /*
     * Plain delegation below.
     */

    @Override
    public Statement createStatement() throws SQLException {
        return connection.createStatement();
    }

    @Override
    public PreparedStatement prepareStatement(String sql) throws SQLException {
        return connection.prepareStatement(sql);
    }

    @Override
    public CallableStatement prepareCall(String sql) throws SQLException {
        return connection.prepareCall(sql);
    }

    @Override
    public String nativeSQL(String sql) throws SQLException {
        return connection.nativeSQL(sql);
    }

    @Override
    public boolean getAutoCommit() throws SQLException {
        return connection.getAutoCommit();
    }

    @Override
    public boolean isClosed() throws SQLException {
        return connection.isClosed();
    }

    @Override
    public DatabaseMetaData getMetaData() throws SQLException {
        return connection.getMetaData();
    }

    @Override
    public void setReadOnly(boolean readOnly) throws SQLException {
        connection.setReadOnly(readOnly);
    }

    @Override
    public boolean isReadOnly() throws SQLException {
        return connection.isReadOnly();
    }

    @Override
    public void setCatalog(String catalog) throws SQLException {
        connection.setCatalog(catalog);
    }

    @Override
    public String getCatalog() throws SQLException {
        return connection.getCatalog();
    }

    @Override
    public void setTransactionIsolation(int level) throws SQLException {
        connection.setTransactionIsolation(level);
    }

    @Override
    public int getTransactionIsolation() throws SQLException {
        return connection.getTransactionIsolation();
    }

    @Override
    public SQLWarning getWarnings() throws SQLException {
        return connection.getWarnings();
    }

    @Override
    public void clearWarnings() throws SQLException {
        connection.clearWarnings();
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException {
        return connection.createStatement(resultSetType, resultSetConcurrency);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        return connection.prepareStatement(sql, resultSetType, resultSetConcurrency);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        return connection.prepareCall(sql, resultSetType, resultSetConcurrency);
    }

    @Override
    public Map<String, Class<?>> getTypeMap() throws SQLException {
        return connection.getTypeMap();
    }

    @Override
    public void setTypeMap(Map<String, Class<?>> map) throws SQLException {
        connection.setTypeMap(map);
    }

    @Override
    public void setHoldability(int holdability) throws SQLException {
        connection.setHoldability(holdability);
    }

    @Override
    public int getHoldability() throws SQLException {
        return connection.getHoldability();
    }

    @Override
    public Savepoint setSavepoint() throws SQLException {
        return connection.setSavepoint();
    }

    @Override
    public Savepoint setSavepoint(String name) throws SQLException {
        return connection.setSavepoint(name);
    }

    @Override
    public void rollback(Savepoint savepoint) throws SQLException {
        connection.rollback(savepoint);
    }

    @Override
    public void releaseSavepoint(Savepoint savepoint) throws SQLException {
        connection.releaseSavepoint(savepoint);
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        return connection.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        return connection.prepareStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        return connection.prepareCall(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException {
        return connection.prepareStatement(sql, autoGeneratedKeys);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException {
        return connection.prepareStatement(sql, columnIndexes);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException {
        return connection.prepareStatement(sql, columnNames);
    }

    @Override
    public Clob createClob() throws SQLException {
        return connection.createClob();
    }

    @Override
    public Blob createBlob() throws SQLException {
        return connection.createBlob();
    }

    @Override
    public NClob createNClob() throws SQLException {
        return connection.createNClob();
    }

    @Override
    public SQLXML createSQLXML() throws SQLException {
        return connection.createSQLXML();
    }

    @Override
    public boolean isValid(int timeout) throws SQLException {
        return connection.isValid(timeout);
    }

    @Override
    public void setClientInfo(String name, String value) throws SQLClientInfoException {
        connection.setClientInfo(name, value);
    }

    @Override
    public void setClientInfo(Properties properties) throws SQLClientInfoException {
        connection.setClientInfo(properties);
    }

    @Override
    public String getClientInfo(String name) throws SQLException {
        return connection.getClientInfo(name);
    }

    @Override
    public Properties getClientInfo() throws SQLException {
        return connection.getClientInfo();
    }

    @Override
    public Array createArrayOf(String typeName, Object[] elements) throws SQLException {
        return connection.createArrayOf(typeName, elements);
    }

    @Override
    public Struct createStruct(String typeName, Object[] attributes) throws SQLException {
        return connection.createStruct(typeName, attributes);
    }

    @Override
    public void setSchema(String schema) throws SQLException {
        connection.setSchema(schema);
    }

    @Override
    public String getSchema() throws SQLException {
        return connection.getSchema();
    }

    @Override
    public void abort(Executor executor) throws SQLException {
        connection.abort(executor);
    }

    @Override
    public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException {
        connection.setNetworkTimeout(executor, milliseconds);
    }

    @Override
    public int getNetworkTimeout() throws SQLException {
        return connection.getNetworkTimeout();
    }

    @Override
    public <T> T unwrap(Class<T> iface) throws SQLException {
        return connection.unwrap(iface);
    }

    @Override
    public boolean isWrapperFor(Class<?> iface) throws SQLException {
        return connection.isWrapperFor(iface);
    }

    /**
     * Always forces auto-commit OFF, whatever {@code autoCommit} requests.
     *
     * <p>Per the JDBC contract, switching auto-commit back on commits the active transaction.
     * Ignoring the request presumably keeps the transaction manager's post-completion auto-commit
     * restore from committing the work before the TM has decided.
     */
    @Override
    public void setAutoCommit(boolean autoCommit) throws SQLException {
        if (connection != null) {
            connection.setAutoCommit(false);
        }
    }
}
import com.hz.lbtransaction.transactional.LbTransactionManager; import org.springframework.stereotype.Component; import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.ModelAndView; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @Component public class RequestInterceptor implements HandlerInterceptor { @Override public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception { String groupId = request.getHeader("groupId"); String transactionCount = request.getHeader("transactionCount"); LbTransactionManager.setCurrentGroupId(groupId); LbTransactionManager.setTransactionCount(Integer.valueOf(transactionCount == null ? "0" : transactionCount)); return true; } @Override public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception { } @Override public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception { } }
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

/**
 * MVC configuration that installs {@link RequestInterceptor} for all request paths.
 *
 * <p>This implements {@link WebMvcConfigurer} directly. {@code WebMvcConfigurerAdapter} is
 * deprecated since Spring 5.0 because the interface now provides default methods.
 */
@Configuration
public class WebAppConfig implements WebMvcConfigurer {

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(new RequestInterceptor());
    }
}
import com.alibaba.fastjson.JSONObject; import io.netty.bootstrap.Bootstrap; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.string.StringDecoder; import io.netty.handler.codec.string.StringEncoder; import org.springframework.beans.factory.InitializingBean; import org.springframework.stereotype.Component; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @Component public class NettyClient implements InitializingBean { public NettyClientHandler client = null; private static ExecutorService executorService = Executors.newCachedThreadPool(); @Override public void afterPropertiesSet() throws Exception { start("localhost", 8080); } public void start(String hostName, Integer port) { client = new NettyClientHandler(); Bootstrap b = new Bootstrap(); EventLoopGroup group = new NioEventLoopGroup(); b.group(group) .channel(NioSocketChannel.class) .option(ChannelOption.TCP_NODELAY, true) .handler(new ChannelInitializer<SocketChannel>() { protected void initChannel(SocketChannel socketChannel) throws Exception { ChannelPipeline pipeline = socketChannel.pipeline(); pipeline.addLast("decoder", new StringDecoder()); pipeline.addLast("encoder", new StringEncoder()); pipeline.addLast("handler", client); } }); try { b.connect(hostName, port).sync(); } catch (InterruptedException e) { e.printStackTrace(); } } public void send(JSONObject jsonObject) { try { client.call(jsonObject); } catch (Exception e) { e.printStackTrace(); } } }
package com.hz.lbtransaction.netty; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.hz.lbtransaction.transactional.LbTransaction; import com.hz.lbtransaction.transactional.LbTransactionManager; import com.hz.lbtransaction.transactional.TransactionType; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; public class NettyClientHandler extends ChannelInboundHandlerAdapter { private ChannelHandlerContext context; @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { context = ctx; } @Override public synchronized void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { System.out.println("接受数据:" + msg.toString()); JSONObject jsonObject = JSON.parseObject((String) msg); String groupId = jsonObject.getString("groupId"); String command = jsonObject.getString("command"); System.out.println("接收command:" + command); // 对事务进行操作 LbTransaction lbTransaction = LbTransactionManager.getLbTransaction(groupId); if (command.equals("rollback")) { lbTransaction.setTransactionType(TransactionType.rollback); } else if (command.equals("commit")) { lbTransaction.setTransactionType(TransactionType.commit); } lbTransaction.getTask().signalTask(); } public synchronized Object call(JSONObject data) throws Exception { context.writeAndFlush(data.toJSONString()); return null; } }
package com.hz.lbtransaction.transactional; import com.hz.lbtransaction.util.Task; public class LbTransaction { private String groupId; private String transactionId; private TransactionType transactionType; // commit-待提交,rollback-待回滚 private Task task = new Task(); public LbTransaction(String groupId, String transactionId) { this.groupId = groupId; this.transactionId = transactionId; this.task = new Task(); } public LbTransaction(String groupId, String transactionId, TransactionType transactionType) { this.groupId = groupId; this.transactionId = transactionId; this.transactionType = transactionType; } public Task getTask() { return task; } public String getGroupId() { return groupId; } public void setGroupId(String groupId) { this.groupId = groupId; } public String getTransactionId() { return transactionId; } public void setTransactionId(String transactionId) { this.transactionId = transactionId; } public TransactionType getTransactionType() { return transactionType; } public void setTransactionType(TransactionType transactionType) { this.transactionType = transactionType; } }
package com.hz.lbtransaction.transactional; import com.alibaba.fastjson.JSONObject; import com.hz.lbtransaction.netty.NettyClient; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import java.util.HashMap; import java.util.Map; import java.util.UUID; @Component public class LbTransactionManager { private static NettyClient nettyClient; private static ThreadLocal<LbTransaction> current = new ThreadLocal<>(); private static ThreadLocal<String> currentGroupId = new ThreadLocal<>(); private static ThreadLocal<Integer> transactionCount = new ThreadLocal<>(); @Autowired public void setNettyClient(NettyClient nettyClient) { LbTransactionManager.nettyClient = nettyClient; } public static Map<String, LbTransaction> LB_TRANSACION_MAP = new HashMap<>(); /** * 创建事务组,并且返回groupId * @return */ public static String createLbTransactionGroup() { String groupId = UUID.randomUUID().toString(); JSONObject jsonObject = new JSONObject(); jsonObject.put("groupId", groupId); jsonObject.put("command", "create"); nettyClient.send(jsonObject); System.out.println("创建事务组"); currentGroupId.set(groupId); return groupId; } /** * 创建子事务 * @param groupId * @return */ public static LbTransaction createLbTransaction(String groupId) { String transactionId = UUID.randomUUID().toString(); LbTransaction lbTransaction = new LbTransaction(groupId, transactionId); LB_TRANSACION_MAP.put(groupId, lbTransaction); current.set(lbTransaction); addTransactionCount(); System.out.println("创建事务"); return lbTransaction; } /** * 注册事务 * @param lbTransaction * @param isEnd * @param transactionType * @return */ public static LbTransaction addLbTransaction(LbTransaction lbTransaction, Boolean isEnd, TransactionType transactionType) { JSONObject jsonObject = new JSONObject(); jsonObject.put("groupId", lbTransaction.getGroupId()); jsonObject.put("transactionId", lbTransaction.getTransactionId()); jsonObject.put("transactionType", transactionType); 
jsonObject.put("command", "add"); jsonObject.put("isEnd", isEnd); jsonObject.put("transactionCount", LbTransactionManager.getTransactionCount()); nettyClient.send(jsonObject); System.out.println("添加事务"); return lbTransaction; } public static LbTransaction getLbTransaction(String groupId) { return LB_TRANSACION_MAP.get(groupId); } public static LbTransaction getCurrent() { return current.get(); } public static String getCurrentGroupId() { return currentGroupId.get(); } public static void setCurrentGroupId(String groupId) { currentGroupId.set(groupId); } public static Integer getTransactionCount() { return transactionCount.get(); } public static void setTransactionCount(int i) { transactionCount.set(i); } public static Integer addTransactionCount() { int i = (transactionCount.get() == null ? 0 : transactionCount.get()) + 1; transactionCount.set(i); return i; } }
package com.hz.lbtransaction.transactional; public enum TransactionType { commit, rollback; }
package com.hz.lbtransaction.util; import com.hz.lbtransaction.transactional.LbTransactionManager; import org.apache.http.HttpStatus; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpGet; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClients; import org.apache.http.util.EntityUtils; public class HttpClient { public static String get(String url) { String result = ""; try { CloseableHttpClient httpClient = HttpClients.createDefault(); HttpGet httpGet = new HttpGet(url); httpGet.addHeader("Content-type", "application/json"); httpGet.addHeader("groupId", LbTransactionManager.getCurrentGroupId()); httpGet.addHeader("transactionCount", String.valueOf(LbTransactionManager.getTransactionCount())); CloseableHttpResponse response = httpClient.execute(httpGet); if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) { result = EntityUtils.toString(response.getEntity(), "utf-8"); } response.close(); } catch (Exception e) { e.printStackTrace(); } return result; } }
package com.hz.lbtransaction.util; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; public class Task { private Lock lock = new ReentrantLock(); private Condition condition = lock.newCondition(); public void waitTask() { try { lock.lock(); condition.await(); } catch (InterruptedException e) { e.printStackTrace(); } finally { lock.unlock(); } } public void signalTask() { lock.lock(); condition.signal(); lock.unlock(); } }