Redis实现分布式锁
之前写过一篇博客,里面吭哧吭哧半天,使用Redis实现了一个分布式锁。
今天闲来没事看源码,突然发现redis set命令的用法可以直接指定nx和ex,文档中没有明说这是个原子方法,但是后面给出了一个例子使用set nx ex的方法实现了redis锁。
感觉应该是原子性的,挺好。
相比这篇文章里的方法http://www.cnblogs.com/kangoroo/p/7298370.html,有两个优点:
1)简单,之前的那篇文章里使用getSet方法,折腾了一顿,就是怕setnx之后expire成功不了,这里直接原子性的话,省事多了。
2)解决了超时误删锁引入的竞态问题,之前我们在value中保留了时间,这里我们可以保留一个uuid,判断是自己的锁再删除,避免误删。
直接粘在下面,以后实现了实例再回来。api出处:http://doc.redisfans.com/string/set.html
命令 SET resource-name anystring NX EX max-lock-time 是一种在 Redis 中实现锁的简单方法。
客户端执行以上的命令:
- 如果服务器返回 OK ,那么这个客户端获得锁。
- 如果服务器返回 NIL ,那么客户端获取锁失败,可以在稍后再重试。
设置的过期时间到达之后,锁将自动释放。
可以通过以下修改,让这个锁实现更健壮:
- 不使用固定的字符串作为键的值,而是设置一个不可猜测(non-guessable)的长随机字符串,作为口令串(token)。
- 不使用 DEL 命令来释放锁,而是发送一个 Lua 脚本,这个脚本只在客户端传入的值和键的口令串相匹配时,才对键进行删除。
这两个改动可以防止持有过期锁的客户端误删现有锁的情况出现。
以下是一个简单的解锁脚本示例:
if redis.call("get",KEYS[1]) == ARGV[1]
then
return redis.call("del",KEYS[1])
else
return 0
end
这个脚本可以通过 EVAL ...script... 1 resource-name token-value 命令来调用。
提供一个分布式锁的实现
# -*- coding:utf-8 -*- from __future__ import print_function import redis import time import multiprocessing from contextlib import contextmanager as _contextmanager # 简单创建redis的客户端 r = redis.Redis(host='localhost', port=6379, db=0) # 分布式锁实现 # finally中验证本线程是否获得锁, 是为了防止误删别的线程获取的锁 @_contextmanager def dist_lock(client, key): dist_lock_key = 'lock:%s' % key is_acquire_lock = False try: is_acquire_lock = _acquire_lock(client, dist_lock_key) yield finally: if is_acquire_lock: _release_lock(client, dist_lock_key) # 尝试获取锁 # 成功: 返回True, 失败: 抛出异常 # 使用set nx ex原语, 使得setnx和expire操作成为原子操作 def _acquire_lock(client, key): is_lock = r.set(key, 1, nx=True, ex=10) if not is_lock: raise Exception("already locked!") return is_lock # 释放锁 # 简单删除key # 如果删除失败, 锁也会通过expire时间超时 def _release_lock(client, key): client.delete(key) # 测试函数 # 获取锁成功, 打印成功, 并持有锁3s # 获取锁失败, 直接打印 def func(): while 1: try: with dist_lock(r, 'key'): print("*", end='') time.sleep(3) except Exception, ex: print('!', end='') # 多进程启动 # 这种模式下, 线程锁无效, 可以验证分布式锁 process_list = list() for i in range(2): process_list.append(multiprocessing.Process(target=func)) for process in process_list: process.start() for process in process_list: process.join()