LB中使用到的一致性Hash算法的简单实现
关于一致性hash算法,可以参考这篇文章: https://zhuanlan.zhihu.com/p/34985026
1、类图(Class Diagram)
2、代码实现
2.1、Node 类:每个 Node 代表集群中的一个节点,具体来说就是一台物理机器;
package consistencyhash;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;

/**
 * One node of the cluster, i.e. a single physical machine identified by its domain and IP,
 * carrying an in-memory key/value store that stands in for the data cached on it.
 *
 * <p>Lombok generates the {@code (domain, ip)} constructor, getters for all fields, and a
 * {@code toString()} that omits {@code data}. The string form therefore depends only on domain
 * and IP and does not change as entries are cached; this matters because the consistent-hash
 * cluster derives ring positions from {@code toString()}.
 *
 * <p>{@code data} is a {@link ConcurrentHashMap}, so null keys and values are rejected with a
 * {@link NullPointerException}.
 *
 * @author xfyou
 * @date 2019/9/2
 */
@Getter
@RequiredArgsConstructor
@ToString(exclude = "data")
public class Node {

    private final String domain;
    private final String ip;
    /** Cached entries held by this node. */
    private final Map<String, Object> data = new ConcurrentHashMap<>();

    /** Stores {@code value} under {@code key}, replacing any previous value. */
    public <T> void put(String key, T value) {
        data.put(key, value);
    }

    /** Removes the entry for {@code key}, if present. */
    public void remove(String key) {
        data.remove(key);
    }

    /**
     * Returns the value stored under {@code key}, or {@code null} if there is none.
     *
     * <p>The cast to {@code T} is unchecked: if the stored value is not of the type the caller
     * expects, the {@link ClassCastException} surfaces at the call site, not here.
     */
    @SuppressWarnings("unchecked")
    public <T> T get(String key) {
        return (T) data.get(key);
    }
}
2.2、AbstractCluster 类:集群的抽象基类;
package consistencyhash;

import java.util.ArrayList;
import java.util.List;

/**
 * Base type for a cluster of {@link Node}s that can route a key to one of them.
 *
 * <p>Subclasses choose the routing strategy; this class only owns the list of registered
 * physical nodes, which is exposed as a protected field.
 *
 * @author xfyou
 * @date 2019/9/2
 */
public abstract class AbstractCluster {

    /** Physical nodes currently registered; a plain, unsynchronized {@link ArrayList}. */
    protected final List<Node> nodes = new ArrayList<>();

    /** Creates a cluster with no nodes; the node list is set up by its field initializer. */
    public AbstractCluster() {
    }

    /** Registers {@code node} with the cluster. */
    public abstract void addNode(Node node);

    /** Unregisters {@code node}; how the node is matched is up to the implementation. */
    public abstract void removeNode(Node node);

    /** Returns the node responsible for {@code key}. */
    public abstract Node get(String key);
}
2.3、ConsistencyHashCluster 类:集群类,一致性 Hash 算法的具体实现类;
package consistencyhash; import com.google.common.hash.Hashing; import java.nio.charset.StandardCharsets; import java.util.SortedMap; import java.util.TreeMap; import java.util.stream.IntStream; /** * @author xfyou * @date 2019/9/2 */ public class ConsistencyHashCluster extends AbstractCluster { private final SortedMap<Long, Node> virNodes = new TreeMap<>(); private static final int VIR_NODE_COUNT = 160; @Override public void addNode(Node node) { this.nodes.add(node); IntStream.range(0, VIR_NODE_COUNT / 4).forEach(i -> { byte[] digest = Hashing.md5().hashBytes((node.toString() + i).getBytes(StandardCharsets.UTF_8)).asBytes(); for (int h = 0; h < 4; h++) { virNodes.put(hash(digest, h), node); } }); } /** * 物理节点被删除的话,这个物理节点所对应的所有的虚拟节点也同时被删 */ @Override public void removeNode(Node node) { nodes.removeIf(o -> node.getIp().equals(o.getIp())); IntStream.range(0, VIR_NODE_COUNT / 4).forEach(i -> { byte[] digest = Hashing.md5().hashBytes((node.toString() + i).getBytes(StandardCharsets.UTF_8)).asBytes(); for (int h = 0; h < 4; h++) { virNodes.remove(hash(digest, h)); } }); } @Override public Node get(String key) { long hash = calHash(key); SortedMap<Long, Node> subMap = hash >= virNodes.lastKey() ? virNodes.tailMap(0L) : virNodes.tailMap(hash); if (subMap.isEmpty()) { return virNodes.get(virNodes.firstKey()); } System.out.println("hash=" + hash + ",subMap.firstKey=" + subMap.firstKey()); return subMap.get(subMap.firstKey()); } private long calHash(String key) { byte[] keyBytes = Hashing.md5().hashBytes(key.getBytes(StandardCharsets.UTF_8)).asBytes(); return hash(keyBytes, 0); } /** * 取MD5后16个字节中的连续的4个字节并通过移位操作来转换为 long 类型的 hash 值 */ private long hash(byte[] digest, int number) { return (((long) (digest[3 + number * 4] & 0xFF) << 24) | ((long) (digest[2 + number * 4] & 0xFF) << 16) | ((long) (digest[1 + number * 4] & 0xFF) << 8) | (digest[number * 4] & 0xFF)) & 0xFFFFFFFFL; } }
2.4、Test 类:测试类
package consistencyhash; import java.util.stream.IntStream; /** * @author xfyou * @date 2019/9/2 */ public class Test { private static final int DATA_CONT = 20; private static final String PRE_KEY = "PRE_KEY"; public static void main(String[] args) { AbstractCluster cluster = new ConsistencyHashCluster(); cluster.addNode(new Node("c1.yywang.info", "192.168.0.1")); cluster.addNode(new Node("c2.yywang.info", "192.168.0.2")); cluster.addNode(new Node("c3.yywang.info", "192.168.0.3")); IntStream.range(0, DATA_CONT).forEach(index -> { Node node = cluster.get(PRE_KEY + index); node.put(PRE_KEY + index, "cached_data"); }); System.out.println("数据分布情况:"); cluster.nodes.forEach(node -> { System.out.println("IP:" + node.getIp() + ",数据量:" + node.getData().size()); }); cluster.removeNode(new Node("c1.yywang.info", "192.168.0.1")); // 查询命中率,如果没有命中则需要从后端 DB 中查询 long hitCount = IntStream.range(0, DATA_CONT).filter(index -> cluster.get(PRE_KEY + index).get(PRE_KEY + index) != null).count(); System.out.println("hitCount=" + hitCount); System.out.println("缓存命中率:" + hitCount * 1f / DATA_CONT); } }