1.一致性hash算法的介绍
一致性 hash 算法由麻省理工学院的 Karger 及其合作者于1997年提出的,算法提出之初是用于大规模缓存系统的负载均衡。它的工作过程是这样的,首先根据 ip 或者其他的信息为缓存节点生成一个 hash,并将这个 hash 投射到 [0, 232 - 1] 的圆环上。当有查询或写入请求时,则为缓存项的 key 生成一个 hash 值。然后查找第一个大于或等于该 hash 值的缓存节点,并到这个节点中查询或写入缓存项。如果当前节点挂了,则在下一次查询或写入缓存时,为缓存项查找另一个大于其 hash 值的缓存节点即可。大致效果如下图所示,每个缓存节点在圆环上占据一个位置。如果缓存项的 key 的 hash 值小于缓存节点 hash 值,则到该缓存节点中存储或读取缓存项。比如下面绿色点对应的缓存项将会被存储到 cache-2 节点中。由于 cache-3 挂了,原本应该存到该节点中的缓存项最终会存储到 cache-4 节点中。
下面来看看一致性 hash 在 Dubbo 中的应用。我们把上图的缓存节点替换成 Dubbo 的服务提供者,于是得到了下图:
这里相同颜色的节点均属于同一个服务提供者,比如 Invoker1-1,Invoker1-2,……, Invoker1-160。这样做的目的是通过引入虚拟节点,让 Invoker 在圆环上分散开来,避免数据倾斜问题。所谓数据倾斜是指,由于节点不够分散,导致大量请求落到了同一个节点上,而其他节点只会接收到了少量请求的情况。比如:
如上,由于 Invoker-1 和 Invoker-2 在圆环上分布不均,导致系统中75%的请求都会落到 Invoker-1 上,只有 25% 的请求会落到 Invoker-2 上。解决这个问题办法是引入虚拟节点,通过虚拟节点均衡各个节点的请求量。
2.实现方式
2.1Java代码实现方式
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.SortedMap;
import java.util.TreeMap;
public class ConsistentHashing {
// 使用 TreeMap 存储虚拟节点
private final SortedMap<Integer, String> circle = new TreeMap<>();
// 每个真实节点的虚拟节点数量
private final int numberOfReplicas;
public ConsistentHashing(int numberOfReplicas, String[] nodes) {
this.numberOfReplicas = numberOfReplicas;
for (String node : nodes) {
addNode(node);
}
}
private int hash(String key) {
try {
MessageDigest md = MessageDigest.getInstance("SHA-256");
md.update(key.getBytes());
byte[] digest = md.digest();
return Math.abs(digest[0]) + ((digest[1] & 0xFF) << 8) + ((digest[2] & 0xFF) << 16) + ((digest[3] & 0xFF) << 24);
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("Hashing algorithm not found", e);
}
}
public void addNode(String node) {
for (int i = 0; i < numberOfReplicas; i++) {
String virtualNode = node + "&&&" + i;
circle.put(hash(virtualNode), virtualNode);
}
}
public void removeNode(String node) {
for (int i = 0; i < numberOfReplicas; i++) {
String virtualNode = node + "&&&" + i;
circle.remove(hash(virtualNode));
}
}
public String getNode(String key) {
if (circle.isEmpty()) {
return null;
}
int hashKey = hash(key);
SortedMap<Integer, String> tailMap = circle.tailMap(hashKey);
Integer nodeHash = tailMap.isEmpty() ? circle.firstKey() : tailMap.firstKey();
return circle.get(nodeHash);
}
// Example usage
public static void main(String[] args) {
String[] nodes = {"node1", "node2", "node3"};
ConsistentHashing ch = new ConsistentHashing(3, nodes);
String key = "my_key";
System.out.println("Node for key '" + key + "': " + ch.getNode(key));
}
}
代码解释
0.hash值的计算方式:
private int hash(String key){...}
这行代码的目的是将字节数组 digest 的前四个字节组合成一个 32 位的整数,用于作为哈希值。通过这种方式,可以将任意长度的字符串转换为一个固定的 32 位整数,适用于一致性哈希等场景。
其中公式:Math.abs(digest[0]) + ((digest[1] & 0xFF) << 8) + ((digest[2] & 0xFF) << 16) + ((digest[3] & 0xFF) << 24)的详细解释:
1.提取第一个字节并取绝对值: digest[0]:获取字节数组的第一个字节。 Math.abs(digest[0]):取第一个字节的绝对值。这是因为字节是有符号的,范围是 -128 到 127,取绝对值确保结果为正数。 2.提取第二个字节并左移 8 位: digest[1]:获取字节数组的第二个字节。 digest[1] & 0xFF:将第二个字节转换为无符号整数。0xFF 是一个 8 位的全 1 二进制数(即 255),与字节进行按位与操作,确保结果为 0 到 255 之间的无符号整数。 (digest[1] & 0xFF) << 8:将无符号整数左移 8 位,相当于乘以 256。 3.提取第三个字节并左移 16 位: digest[2]:获取字节数组的第三个字节。 digest[2] & 0xFF:将第三个字节转换为无符号整数。 (digest[2] & 0xFF) << 16:将无符号整数左移 16 位,相当于乘以 65536。 4.提取第四个字节并左移 24 位: digest[3]:获取字节数组的第四个字节。 digest[3] & 0xFF:将第四个字节转换为无符号整数。 (digest[3] & 0xFF) << 24:将无符号整数左移 24 位,相当于乘以 16777216。 5.将所有部分相加: Math.abs(digest[0]) + ((digest[1] & 0xFF) << 8) + ((digest[2] & 0xFF) << 16) + ((digest[3] & 0xFF) << 24):将上述四个部分相加,得到一个 32 位的整数。
1.环形表示:使用 SortedMap(TreeMap)来存储虚拟节点,以排序的方式存储,从而实现高效的检索。 2.哈希函数:hash 方法计算给定键的 SHA-256 哈希值并返回一个整数哈希值。
3.节点管理:
- 添加节点:addNode 方法为每个节点创建多个副本,并将它们添加到哈希环中。
- 移除节点:removeNode 方法移除与给定节点对应的虚拟节点。
4.键分配:getNode 方法通过计算键的哈希值,找到哈希环中离该键最近的虚拟节点,从而确定哪个节点负责该键
Python代码实现方式
import hashlib
from bisect import bisect
class ConsistentHashing:
def __init__(self, nodes=None, replicas=3):
self.replicas = replicas
self.ring = {}
self.sorted_keys = []
if nodes:
for node in nodes:
self.add_node(node)
def _hash(self, key):
return int(hashlib.sha256(key.encode()).hexdigest(), 16)
def add_node(self, node):
for i in range(self.replicas):
virtual_node = f"{node}#{i}"
self.ring[self._hash(virtual_node)] = node
self.sorted_keys.append(self._hash(virtual_node))
self.sorted_keys.sort()
def remove_node(self, node):
for i in range(self.replicas):
virtual_node = f"{node}#{i}"
del self.ring[self._hash(virtual_node)]
self.sorted_keys.remove(self._hash(virtual_node))
def get_node(self, key):
if not self.ring:
return None
hash_key = self._hash(key)
index = bisect(self.sorted_keys, hash_key) % len(self.sorted_keys)
return self.ring[self.sorted_keys[index]]
# Example usage
nodes = ['node1', 'node2', 'node3']
ch = ConsistentHashing(nodes)
print(ch.get_node('my_key')) # Get the node responsible for 'my_key'