基于 Spring Data Redis 实现一个分布式锁工具类
基础依赖
<!-- 2.x.x.RELEASE is a placeholder: replace it with the Spring Boot 2.x version your project uses -->
<dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-data-redis</artifactId> <version>2.x.x.RELEASE</version> </dependency>
核心类 DRedisLock
package com.idanchuang.component.redis.util.task; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.data.redis.core.script.DefaultRedisScript; import java.net.InetAddress; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.UUID; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; /** * 基于Redis的分布式锁(线程内可重入) * @author yjy * @date 2019/11/27 11:07 **/ public class DRedisLock implements Lock { private static final Logger log = LoggerFactory.getLogger(DRedisLock.class); /** 默认的锁超时时间 */ public final static long DEFAULT_TIMEOUT = 30000L; /** 锁key前缀 */ public final static String LOCK_PREFIX = "D_LOCK_"; /** 默认的获取锁超时时间 */ public final static long DEFAULT_TRY_LOCK_TIMEOUT = 10000L; /** 等待锁时, 自旋尝试的周期, 默认10毫秒 */ public final static long DEFAULT_LOOP_INTERVAL = 10L; /** 序列值, 用于确保锁value的唯一性 */ private static AtomicLong SERIAL_NUM; /** 最大序列值 */ private static long MAX_SERIAL; /** 本机host */ private static String CURRENT_HOST; /** StringRedisTemplate */ private final StringRedisTemplate redisTemplate; /** 锁Key */ private final String lockKey; /** 锁超时时间(单位毫秒) */ private final long timeout; /** 等待锁时, 自旋尝试的周期(单位毫秒) */ private final long loopInterval; /** 主机+线程id */ private final String hostThreadId; /** 锁定值 */ private final String lockValue; /** 是否重入 */ private boolean reentrant = false; /** 是否持有锁 */ private boolean locked = false; static { try { SERIAL_NUM = new AtomicLong(0); MAX_SERIAL = 99999999L; CURRENT_HOST = InetAddress.getLocalHost().getHostAddress(); } catch (UnknownHostException e) { CURRENT_HOST = UUID.randomUUID().toString(); log.warn("DRedisLock > can not get local host, use uuid: {}", CURRENT_HOST); } } public DRedisLock(String lockName) { this(lockName, DEFAULT_TIMEOUT, 
DEFAULT_LOOP_INTERVAL); } public DRedisLock(String lockName, long timeout) { this(lockName, timeout, DEFAULT_LOOP_INTERVAL); } public DRedisLock(String lockName, long timeout, long loopInterval) { if (lockName == null) { throw new IllegalArgumentException("lockName must assigned"); } this.redisTemplate = SpringUtil.getBean(StringRedisTemplate.class); this.lockKey = LOCK_PREFIX + lockName; this.timeout = timeout; this.loopInterval = loopInterval; this.hostThreadId = CURRENT_HOST + ":" + Thread.currentThread().getId(); this.lockValue = this.hostThreadId + ":" + getNextSerial(); } /** * 获取锁, 如果锁被持有, 将一直等待, 直到超出默认的的DEFAULT_TRY_LOCK_TIMEOUT */ @Override public void lock() { try { if (!tryLock(DEFAULT_TRY_LOCK_TIMEOUT, TimeUnit.MILLISECONDS)) { throw new RuntimeException("try lock timeout, lockKey: " + this.lockKey); } } catch (InterruptedException e) { throw new RuntimeException(e); } } /** * 尝试获取锁, 如果锁被持有, 则等待相应的时间(等待锁时可被中断) * @throws InterruptedException 被中断等待 */ @Override public void lockInterruptibly() throws InterruptedException { if (!tryLock(DEFAULT_TRY_LOCK_TIMEOUT, TimeUnit.MILLISECONDS, true)) { throw new RuntimeException("try lock timeout, lockKey: " + this.lockKey); } } /** * 尝试获取锁, 只会立即获取一次, 如果锁被占用, 则返回false, 获取成功则返回true * @return 是否成功获取锁 */ @Override public boolean tryLock() { try { Boolean success = setIfAbsent(this.lockKey, this.lockValue, this.timeout / 1000); if (success != null && success) { locked = true; log.debug("Lock success, lockKey: {}, lockValue: {}", this.lockKey, this.lockValue); return true; } else { // 如果持有锁的是当前线程, 则重入 String script = "local val,ttl=ARGV[1],ARGV[2] "; script += "if redis.call('EXISTS', KEYS[1])==1 then local curValue = redis.call('GET', KEYS[1]) if string.find(curValue, val)==1 then local curTtl = redis.call('TTL', KEYS[1]) redis.call('EXPIRE', KEYS[1], curTtl + ttl) return true else return false end else return false end"; DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>(); 
redisScript.setResultType(Boolean.class); redisScript.setScriptText(script); List<String> keys = new ArrayList<>(); keys.add(this.lockKey); success = redisTemplate.execute(redisScript, keys, this.hostThreadId, String.valueOf(Math.max(this.timeout / 1000L, 1))); if (success != null && success) { this.reentrant = true; locked = true; log.debug("Lock reentrant success, lockKey: {}, lockValue: {}", this.lockKey, this.lockValue); return true; } } } catch (Exception e) { log.error("tryLock error, do unlock, lockKey: {}, lockValue: {}", this.lockKey, lockValue, e); unlock(); } return false; } /** * 使用lua脚本的方式实现setIfAbsent, 因为当业务应用使用了redisson时, 直接使用template的setIfAbsent返回值为null * @param key key * @param value 值 * @param timeoutSecond 超时时间 * @return 是否成功设值 */ private Boolean setIfAbsent(String key, String value, long timeoutSecond) { String script = "local val,ttl=ARGV[1],ARGV[2] "; script += "if redis.call('EXISTS', KEYS[1])==1 then return false else redis.call('SET', KEYS[1], ARGV[1]) redis.call('EXPIRE', KEYS[1], ARGV[2]) return true end"; DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>(); redisScript.setResultType(Boolean.class); redisScript.setScriptText(script); List<String> keys = new ArrayList<>(); keys.add(key); return redisTemplate.execute(redisScript, keys, value, String.valueOf(timeoutSecond)); } /** * 尝试获取锁, 如果锁被占用, 则持续尝试获取, 直到超过指定的time时间 * @param time 等待锁的时间 * @param unit time的单位 * @return 是否成功获取锁 * @throws InterruptedException 被中断 */ @Override public boolean tryLock(long time, TimeUnit unit) throws InterruptedException { return tryLock(time, unit, false); } /** * 尝试获取锁, 如果锁被占用, 则持续尝试获取, 直到超过指定的time时间 * @param time 等待锁的时间 * @param unit time的单位 * @param interruptibly 等待是否可被中断 * @return 是否成功获取锁 * @throws InterruptedException 被中断 */ private boolean tryLock(long time, TimeUnit unit, boolean interruptibly) throws InterruptedException { long millis = unit.convert(time, TimeUnit.MILLISECONDS); long current = System.currentTimeMillis(); do { if 
(interruptibly && Thread.interrupted()) { throw new RuntimeException("tryLock interrupted"); } if (tryLock()) { return true; } Thread.sleep(loopInterval); } while (System.currentTimeMillis() - current < millis); return false; } /** * 释放锁 */ @Override public void unlock() { try { if (!locked) { return; } if (this.reentrant) { log.debug("Unlock reentrant success, lockKey: {}, lockValue: {}", this.lockKey, this.lockValue); return; } // 使用lua脚本处理锁判断和释放 String script = "if redis.call('get', KEYS[1]) == ARGV[1] then redis.call('del', KEYS[1]) return true else return false end"; DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>(); redisScript.setResultType(Boolean.class); redisScript.setScriptText(script); Boolean res = this.redisTemplate.execute(redisScript, Collections.singletonList(this.lockKey), this.lockValue); if (res != null && res) { locked = false; log.debug("Unlock success, lockKey: {}, lockValue: {}", this.lockKey, this.lockValue); return; } } catch (Exception e) { log.error("Unlock error", e); } log.warn("Unlock failed, lockKey: {}, lockValue: {}", this.lockKey, this.lockValue); } @Override public Condition newCondition() { throw new UnsupportedOperationException(); } /** * @return 下一个序列值 */ private static synchronized long getNextSerial() { long serial = SERIAL_NUM.incrementAndGet(); if (serial > MAX_SERIAL) { serial = serial - MAX_SERIAL; SERIAL_NUM.set(serial); } return serial; } public static AtomicLong getSerialNum() { return SERIAL_NUM; } public static long getMaxSerial() { return MAX_SERIAL; } public static String getCurrentHost() { return CURRENT_HOST; } public String getLockKey() { return lockKey; } public long getTimeout() { return timeout; } public long getLoopInterval() { return loopInterval; } public String getHostThreadId() { return hostThreadId; } public String getLockValue() { return lockValue; } public boolean isReentrant() { return reentrant; } @Override public String toString() { return "DRedisLock{" + "lockKey='" + lockKey 
+ '\'' + ", timeout=" + timeout + ", loopInterval=" + loopInterval + ", hostThreadId='" + hostThreadId + '\'' + ", lockValue='" + lockValue + '\'' + ", reentrant=" + reentrant + '}'; } }
工具类 SpringUtil
package com.idanchuang.component.core.util; import org.springframework.beans.BeansException; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.stereotype.Component; @Component public class SpringUtil implements ApplicationContextAware { private static ApplicationContext applicationContext; public static <T> T getBean(Class<T> tClass) { checkState(); return applicationContext.getBean(tClass); } public static <T> T getBean(String beanName) { checkState(); return (T)applicationContext.getBean(beanName); } public static <T> T getBean(String beanName, Class<T> requiredType) { checkState(); return applicationContext.getBean(beanName, requiredType); } private static void checkState() { if (SpringUtil.applicationContext == null) { throw new IllegalStateException("SpringUtil applicationContext is unready"); } } @Override public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { SpringUtil.applicationContext = applicationContext; } }
至此, 我们已经可以开始使用分布式锁的功能啦, 如下
DRedisLock lock = new DRedisLock("testa");
// Acquire before entering try (standard Lock idiom): if lock() throws, there is nothing to release.
lock.lock();
try {
    int b = a;
    a = b + 1;
    System.out.println(a);
} finally {
    lock.unlock();
}
不过这样用起来还是有些麻烦: 需要自己实例化 lock 对象并手动加锁、释放锁, 一旦忘记释放, 问题就很大. 所以我又封装了一个 DRedisLocks 类
package com.idanchuang.component.redis.util;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.*;

import static com.idanchuang.component.redis.util.DRedisLock.*;

/**
 * Static helpers that run a block of code while holding a {@link DRedisLock}: the lock is
 * acquired before the block and released in a {@code finally} afterwards.
 *
 * @author yjy
 * @date 2019/11/27 11:07
 **/
public class DRedisLocks {

    private static final Logger log = LoggerFactory.getLogger(DRedisLocks.class);

    /**
     * Runs {@code runnable} under lock {@code lockName} with default try/hold timeouts and spin interval.
     *
     * @param lockName lock name
     * @param runnable code to run
     */
    public static void runWithLock(String lockName, Runnable runnable) {
        runWithLock(lockName, DEFAULT_TRY_LOCK_TIMEOUT, DEFAULT_TIMEOUT, DEFAULT_LOOP_INTERVAL, runnable);
    }

    /**
     * Runs {@code callable} under lock {@code lockName} with default try/hold timeouts and spin interval.
     *
     * @param lockName lock name
     * @param callable code to run
     * @param <V>      result type
     * @return the callable's result
     */
    public static <V> V runWithLock(String lockName, Callable<V> callable) {
        return runWithLock(lockName, DEFAULT_TRY_LOCK_TIMEOUT, DEFAULT_TIMEOUT, DEFAULT_LOOP_INTERVAL, callable);
    }

    /**
     * @param lockName   lock name
     * @param tryTimeout how long to keep trying to acquire the lock (ms)
     * @param runnable   code to run
     */
    public static void runWithLock(String lockName, long tryTimeout, Runnable runnable) {
        runWithLock(lockName, tryTimeout, DEFAULT_TIMEOUT, DEFAULT_LOOP_INTERVAL, runnable);
    }

    /**
     * @param lockName   lock name
     * @param tryTimeout how long to keep trying to acquire the lock (ms)
     * @param callable   code to run
     * @param <V>        result type
     * @return the callable's result
     */
    public static <V> V runWithLock(String lockName, long tryTimeout, Callable<V> callable) {
        return runWithLock(lockName, tryTimeout, DEFAULT_TIMEOUT, DEFAULT_LOOP_INTERVAL, callable);
    }

    /**
     * @param lockName    lock name
     * @param tryTimeout  how long to keep trying to acquire the lock (ms)
     * @param lockTimeout how long the lock may be held before it expires (ms)
     * @param runnable    code to run
     */
    public static void runWithLock(String lockName, long tryTimeout, long lockTimeout, Runnable runnable) {
        runWithLock(lockName, tryTimeout, lockTimeout, DEFAULT_LOOP_INTERVAL, runnable);
    }

    /**
     * @param lockName    lock name
     * @param tryTimeout  how long to keep trying to acquire the lock (ms)
     * @param lockTimeout how long the lock may be held before it expires (ms)
     * @param callable    code to run
     * @param <V>         result type
     * @return the callable's result
     */
    public static <V> V runWithLock(String lockName, long tryTimeout, long lockTimeout, Callable<V> callable) {
        return runWithLock(lockName, tryTimeout, lockTimeout, DEFAULT_LOOP_INTERVAL, callable);
    }

    /**
     * @param lockName     lock name
     * @param tryTimeout   how long to keep trying to acquire the lock (ms)
     * @param lockTimeout  how long the lock may be held before it expires (ms)
     * @param loopInterval pause between acquisition attempts (ms)
     * @param runnable     code to run
     */
    public static void runWithLock(String lockName, long tryTimeout, long lockTimeout, long loopInterval,
                                   Runnable runnable) {
        Callable<Void> callable = () -> {
            runnable.run();
            return null;
        };
        runWithLock(lockName, tryTimeout, lockTimeout, loopInterval, callable);
    }

    /**
     * Acquires the lock, runs {@code callable}, and always releases the lock afterwards.
     *
     * @param lockName     lock name
     * @param tryTimeout   how long to keep trying to acquire the lock (ms)
     * @param lockTimeout  how long the lock may be held before it expires (ms)
     * @param loopInterval pause between acquisition attempts (ms)
     * @param callable     code to run
     * @param <V>          result type
     * @return the callable's result
     * @throws RuntimeException if the lock cannot be acquired in time; checked exceptions from the
     *                          callable (or an interrupt while waiting) are wrapped in one
     */
    public static <V> V runWithLock(String lockName, long tryTimeout, long lockTimeout, long loopInterval,
                                    Callable<V> callable) {
        DRedisLock lock = new DRedisLock(lockName, lockTimeout, loopInterval);
        log.debug("Init DRedisLock > {}", lock);
        try {
            if (!lock.tryLock(tryTimeout, TimeUnit.MILLISECONDS)) {
                throw new RuntimeException("Get redisLock failed, lockName: " + lockName);
            }
            log.debug("Lock successful, lockName: {}", lockName);
            return callable.call();
        } catch (InterruptedException e) {
            // wrapping is unavoidable here, but the thread's interrupt status must survive
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            // unlock() is a no-op when the lock was never acquired
            lock.unlock();
            log.debug("Unlock successful, lockName: {}", lockName);
        }
    }
}
现在我们就可以通过下面这种方式来使用分布式锁了, 而且不用自己手动加锁释放锁, 轻松了不少
DRedisLocks.runWithLock("testa", () -> { int b = a; a = b + 1; System.out.println(a); });
不过, 如果要给整个方法加锁, 这种写法还是不够优雅. 能不能只加一个注解就获得分布式锁的能力呢? 答案当然是可以的, 为此我又新建了几个类
RedisLock 注解类
package com.idanchuang.component.redis.annotation; import java.lang.annotation.*; /** * @author yjy * @date 2020/5/8 9:53 **/ @Target(ElementType.METHOD) @Retention(RetentionPolicy.RUNTIME) @Inherited public @interface RedisLock { /** 锁名称 如果不指定,则为类名:方法名 */ String value() default ""; /** 获取锁的超时时间 ms */ long tryTimeout() default 10000L; /** 持有锁的超时时间 ms */ long lockTimeout() default 30000L; /** 自旋获取锁间隔 ms */ long loopInterval() default 10L; /** 自定义业务key (解析后追加在锁名称中) */ String[] keys() default {}; /** 错误提示信息 */ String errMessage() default ""; }
RedisLockAspect AOP配置类
package com.idanchuang.component.redis.aspect; import com.idanchuang.component.base.exception.common.ErrorCode; import com.idanchuang.component.base.exception.core.ExFactory; import com.idanchuang.component.redis.annotation.RedisLock; import com.idanchuang.component.redis.helper.BusinessKeyHelper; import com.idanchuang.component.redis.util.DRedisLock; import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.annotation.Around; import org.aspectj.lang.annotation.Aspect; import org.aspectj.lang.annotation.Pointcut; import org.aspectj.lang.reflect.MethodSignature; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.stereotype.Component; import org.springframework.util.StringUtils; import java.lang.reflect.Method; import java.util.concurrent.TimeUnit; /** * Aspect for methods with {@link RedisLock} annotation. * * @author yjy */ @Aspect @Component public class RedisLockAspect { private static final Logger log = LoggerFactory.getLogger(RedisLockAspect.class); @Pointcut("@annotation(com.idanchuang.component.redis.annotation.RedisLock)") public void redisLockAnnotationPointcut() { } @Around("redisLockAnnotationPointcut()") public Object invokeWithRedisLock(ProceedingJoinPoint pjp) throws Throwable { Method originMethod = resolveMethod(pjp); RedisLock annotation = originMethod.getAnnotation(RedisLock.class); if (annotation == null) { // Should not go through here. 
throw new IllegalStateException("Wrong state for RedisLock annotation"); } DRedisLock lock = null; String lockName = getName(annotation.value(), originMethod); lockName += BusinessKeyHelper.getKeyName(pjp, annotation.keys()); try { lock = new DRedisLock(lockName, annotation.lockTimeout(), annotation.loopInterval()); // 获取锁, 如果被占用则等待, 直到获取到锁, 或则等待超时 if (lock.tryLock(annotation.tryTimeout(), TimeUnit.MILLISECONDS)) { return pjp.proceed(); } else { String msg = "Get redisLock failed, lockName: " + lockName; log.warn(msg); throw ExFactory.throwWith(ErrorCode.CONFLICT, !StringUtils.isEmpty(annotation.errMessage()) ? msg : annotation.errMessage()); } } finally { // 重点: 释放锁 if (lock != null) { lock.unlock(); } } } /** * 获取lockName前缀 * * @param lockName * @param originMethod * @return java.lang.String * @author sxp * @date 2020/7/3 11:06 */ private String getName(String lockName, Method originMethod) { // if 未指定lockName, 则默认取 类名:方法名 if (StringUtils.isEmpty(lockName)) { return originMethod.getDeclaringClass().getSimpleName() + ":" + originMethod.getName(); } else { return lockName; } } private Method resolveMethod(ProceedingJoinPoint joinPoint) { MethodSignature signature = (MethodSignature) joinPoint.getSignature(); Class<?> targetClass = joinPoint.getTarget().getClass(); Method method = getDeclaredMethodFor(targetClass, signature.getName(), signature.getMethod().getParameterTypes()); if (method == null) { throw new IllegalStateException("Cannot resolve target method: " + signature.getMethod().getName()); } return method; } /** * Get declared method with provided name and parameterTypes in given class and its super classes. * All parameters should be valid. * * @param clazz class where the method is located * @param name method name * @param parameterTypes method parameter type list * @return resolved method, null if not found */ private Method getDeclaredMethodFor(Class<?> clazz, String name, Class<?>... 
parameterTypes) { try { return clazz.getDeclaredMethod(name, parameterTypes); } catch (NoSuchMethodException e) { Class<?> superClass = clazz.getSuperclass(); if (superClass != null) { return getDeclaredMethodFor(superClass, name, parameterTypes); } } return null; } }
至此, 我们已经可以通过注解来实现接口的分布式锁能力
/**
 * Demo service: {@code doSomething()} is annotated with {@code @RedisLock} so that calls are
 * serialized by the lock "customLockName:88888222" (held up to 60 s, waiting up to 20 s to acquire).
 *
 * @author yjy
 * @date 2020/5/8 10:21
 **/
@Component
public class LockService {

    /** Shared counter mutated inside the locked section. */
    private static int a = 0;

    @RedisLock(value = "customLockName:88888222", lockTimeout = 60000L, tryTimeout = 20000L)
    public void doSomething() {
        a++;
        try {
            // hold the lock for a while so concurrent callers have to wait
            Thread.sleep(1000L);
        } catch (InterruptedException e) {
            // restore the interrupt flag instead of swallowing it
            Thread.currentThread().interrupt();
        }
        System.out.println("a: " + a);
    }
}
以上简单地介绍了我们实现的 Redis 分布式锁, 其实它的功能不止这些
它还支持线程内可重入、超时自动释放锁, 注解模式下还支持解析参数对象作为锁资源, 等等
好了, 今天就到这里吧, 拜拜~