ZooKeeper 实现分布式锁
package com.concurrent; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import org.apache.zookeeper.CreateMode; import org.apache.zookeeper.KeeperException; import org.apache.zookeeper.WatchedEvent; import org.apache.zookeeper.Watcher; import org.apache.zookeeper.ZooDefs; import org.apache.zookeeper.ZooKeeper; import org.apache.zookeeper.data.Stat; public class DistributedLock implements Lock, Watcher { private ZooKeeper zk; private String root = "/locks";// 根 private String lockName;// 竞争资源的标志 private String waitNode;// 等待前一个锁 private String myZnode;// 当前锁 private CountDownLatch latch;// 计数器 private int sessionTimeout = 30000; private List<Exception> exception = new ArrayList<Exception>(); /** * 创建分布式锁,使用前请确认config配置的zookeeper服务可用 * * @param config * 127.0.0.1:2181 * @param lockName * 竞争资源标志,lockName中不能包含单词lock */ public DistributedLock(String config, String lockName) { this.lockName = lockName; // 创建一个与服务器的连接 try { zk = new ZooKeeper(config, sessionTimeout, this); Stat stat = zk.exists(root, false); if (stat == null) { // 创建根节点 zk.create(root, new byte[0], ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT); } } catch (IOException e) { exception.add(e); } catch (KeeperException e) { exception.add(e); } catch (InterruptedException e) { exception.add(e); } } /** * zookeeper节点的监视器 */ public void process(WatchedEvent event) { if (this.latch != null) { this.latch.countDown(); } } public void lock() { if (exception.size() > 0) { throw new LockException(exception.get(0)); } try { if (this.tryLock()) { System.out.println("Thread " + Thread.currentThread().getId() + " " + myZnode + " get lock true"); return; } else { waitForLock(waitNode, sessionTimeout);// 等待锁 } } catch (KeeperException e) { throw new LockException(e); } catch 
(InterruptedException e) { throw new LockException(e); } } public boolean tryLock() { try { String splitStr = "_lock_"; if (lockName.contains(splitStr)) throw new LockException("lockName can not contains \\u000B"); // 创建临时子节点 myZnode = zk.create(root + "/" + lockName + splitStr, new byte[0], ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL); System.out.println(myZnode + " is created "); // 取出所有子节点 List<String> subNodes = zk.getChildren(root, false); // 取出所有lockName的锁 List<String> lockObjNodes = new ArrayList<String>(); for (String node : subNodes) { String _node = node.split(splitStr)[0]; if (_node.equals(lockName)) { lockObjNodes.add(node); } } Collections.sort(lockObjNodes); System.out.println(myZnode + "==" + lockObjNodes.get(0)); if (myZnode.equals(root + "/" + lockObjNodes.get(0))) { // 如果是最小的节点,则表示取得锁 return true; } // 如果不是最小的节点,找到比自己小1的节点 String subMyZnode = myZnode.substring(myZnode.lastIndexOf("/") + 1); waitNode = lockObjNodes.get(Collections.binarySearch(lockObjNodes, subMyZnode) - 1); } catch (KeeperException e) { throw new LockException(e); } catch (InterruptedException e) { throw new LockException(e); } return false; } public boolean tryLock(long time, TimeUnit unit) { try { if (this.tryLock()) { return true; } return waitForLock(waitNode, time); } catch (Exception e) { e.printStackTrace(); } return false; } private boolean waitForLock(String lower, long waitTime) throws InterruptedException, KeeperException { Stat stat = zk.exists(root + "/" + lower, true); // 判断比自己小一个数的节点是否存在,如果不存在则无需等待锁,同时注册监听 if (stat != null) { System.out.println("Thread " + Thread.currentThread().getId() + " waiting for " + root + "/" + lower); this.latch = new CountDownLatch(1); this.latch.await(waitTime, TimeUnit.MILLISECONDS); this.latch = null; } return true; } public void unlock() { try { System.out.println("unlock " + myZnode); zk.delete(myZnode, -1); myZnode = null; zk.close(); } catch (InterruptedException e) { e.printStackTrace(); } catch (KeeperException e) 
{ e.printStackTrace(); } } public void lockInterruptibly() throws InterruptedException { this.lock(); } public Condition newCondition() { return null; } public class LockException extends RuntimeException { private static final long serialVersionUID = 1L; public LockException(String e) { super(e); } public LockException(Exception e) { super(e); } } }
package com.concurrent; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; public class ConcurrentTest { private CountDownLatch startSignal = new CountDownLatch(1);// 开始阀门 private CountDownLatch doneSignal = null;// 结束阀门 private CopyOnWriteArrayList<Long> list = new CopyOnWriteArrayList<Long>(); private AtomicInteger err = new AtomicInteger();// 原子递增 private ConcurrentTask[] task = null; public ConcurrentTest(ConcurrentTask... task) { this.task = task; if (task == null) { System.out.println("task can not null"); System.exit(1); } doneSignal = new CountDownLatch(task.length); start(); } /** * @param args * @throws ClassNotFoundException */ private void start() { // 创建线程,并将所有线程等待在阀门处 createThread(); // 打开阀门 startSignal.countDown();// 递减锁存器的计数,如果计数到达零,则释放所有等待的线程 try { doneSignal.await();// 等待所有线程都执行完毕 } catch (InterruptedException e) { e.printStackTrace(); } // 计算执行时间 getExeTime(); } /** * 初始化所有线程,并在阀门处等待 */ private void createThread() { long len = doneSignal.getCount(); for (int i = 0; i < len; i++) { final int j = i; new Thread(new Runnable() { public void run() { try { startSignal.await();// 使当前线程在锁存器倒计数至零之前一直等待 long start = System.currentTimeMillis(); task[j].run(); long end = (System.currentTimeMillis() - start); list.add(end); } catch (Exception e) { err.getAndIncrement();// 相当于err++ } doneSignal.countDown(); } }).start(); } } /** * 计算平均响应时间 */ private void getExeTime() { int size = list.size(); List<Long> _list = new ArrayList<Long>(size); _list.addAll(list); Collections.sort(_list); long min = _list.get(0); long max = _list.get(size - 1); long sum = 0L; for (Long t : _list) { sum += t; } long avg = sum / size; System.out.println("min: " + min); System.out.println("max: " + max); System.out.println("avg: " + avg); System.out.println("err: " + err.get()); } public interface ConcurrentTask { 
void run(); } }
package com.concurrent; import com.concurrent.ConcurrentTest.ConcurrentTask; public class ZkTest { public static void main(String[] args) { Runnable task1 = new Runnable() { public void run() { DistributedLock lock = null; try { lock = new DistributedLock("127.0.0.1:2181", "test1"); // lock = new DistributedLock("127.0.0.1:2182","test2"); lock.lock(); Thread.sleep(3000); System.out.println("===Thread " + Thread.currentThread().getId() + " running"); } catch (Exception e) { e.printStackTrace(); } finally { if (lock != null) lock.unlock(); } } }; new Thread(task1).start(); try { Thread.sleep(1000); } catch (InterruptedException e1) { e1.printStackTrace(); } ConcurrentTask[] tasks = new ConcurrentTask[6]; for (int i = 0; i < tasks.length; i++) { ConcurrentTask task3 = new ConcurrentTask() { public void run() { DistributedLock lock = null; try { lock = new DistributedLock("127.0.0.1:2181", "test2"); lock.lock(); System.out.println("Thread " + Thread.currentThread().getId() + " running"); } catch (Exception e) { e.printStackTrace(); } finally { lock.unlock(); } } }; tasks[i] = task3; } new ConcurrentTest(tasks); } }