用 Python 复现 Jedis 中的客户端分片访问


1 引言

在很多公司中,常见的流程是由 Spark 处理大量数据(如一批特征数据),再通过 Jedis 的客户端分片写入 Redis;而下游可能需要用 Python 读取这些 Redis 中的数据(例如用 TensorFlow 建模)。但 Jedis 的客户端分片机制,在各个 Python Redis 客户端 SDK 中似乎都没有实现。

2 python客户端调研

如《python连接redis sentinel集群》一文所示,Python 通过 master_for 等 API 访问时,第一个参数需要提供分片名称。这本无可厚非,可是如引言所述,我们有大量分片,没法挨个指定。当然,有人会说可以用 Python 写个简单封装,假定有 200 个分片:

  • 先写个for循环获取这200个分片对应的master
  • 再用多线程,每来一个 key 就同时分发到所有分片,通过 master.exists 判断,找到该 key 所在的 Redis 实例
  • 再对该redis实例发起get和set

我试过上述方法:第二步耗时约 30ms,第三步约 0.3ms。所以一旦特征量较大,这种方法就经不起考验。

3 jedis调研

通过《Jedis - ShardedJedisPool 初始化与应用 & hash 算法详解》一文中的流程,可以大致了解 Jedis 的处理过程;
下面通过代码截图进一步了解其细节。

通过定位,首先找到的就是该类:从初始化部分可以看出,传入的是一个分片列表;而在后续的 get 中,先获取 key 对应的分片,再基于该分片执行 get。
继续定位:

key 传进来后,先调用 getShardInfo(key),其执行过程如下:

getKeyTag(key:String) // 结果还是key(意思是进去这个字符串,出来还是这个)
SafeEncoder.encode(getKeyTag(key:String)) // 结果还是key(意思是进去这个字符串,出来还是这个)
algo.hash(key)//的结果是 一个好大的整数,可能是"-3348278837805118665"这种
SortedMap<Long, S> tail = nodes.tailMap(algo.hash(key)) //是基于用户的key的hash值,将nodes表示的treemap中大于等于该值的取出来
tail.get(tail.firstKey())  // 是将tail中第一个取出来然后获取其对应的value
// 注意:若tail为空(即key的hash大于nodes中所有的key),Jedis会回绕到nodes.get(nodes.firstKey()),即取整个环上的第一个节点

那么nodes是什么呢

如上图,nodes 是一个 TreeMap;其中 shardInfo.getWeight() 如果用户没有特别指定,默认就是 1,可以直接代入;

nodes 的 key 是对字符串 分片名称(用户传入)+"*"+"1"+n 计算 hash 得到的值,其中 n 取 0 到 159(共 160 个,即 160*weight 个);假定用户有 200 个分片,如 "shard-{1-200}",那么这里一共有 200*160 个 key;
其中的value就是shardinfo实例,通过shardinfo.getName()就能获取到实际的分片名称了。

4 代码实现

结合上述分析以及实际打印的值,我们做了如下复现;其中最重要的 hash 部分参考了《基于Python3.6实现Java版murmurhash算法》,拿来即用。

这样,我们:

  • 先模仿jedis基于分片构建完整的nodes
  • 等用户输入 key 后,找到该 key 所属分片对应的 Redis 实例;
  • 再针对该实例发起get

使用示例在下面代码的最后。

#-*- coding: utf-8 -*-
import ctypes
import time
from pytreemap import TreeMap
import redis
from redis.sentinel import Sentinel


class RedisShared(object):
    """Read keys that were written through Jedis' client-side sharding.

    Jedis routes each key with a consistent-hash ring.  Every shard
    contributes ``160 * weight`` virtual nodes, placed at
    ``mmh(name + "*" + weight + n)``.  A key belongs to the first node whose
    position is >= ``mmh(key)``; when no such node exists, it wraps around to
    the lowest node on the ring.  This class rebuilds that ring from the shard
    names, resolves each shard through Redis Sentinel, and reads keys from a
    replica (``slave_for``) of the owning shard.
    """

    # Virtual nodes per unit of weight, and the weight used for every shard.
    VIRTUAL_NODES_PER_WEIGHT = 160
    DEFAULT_WEIGHT = 1

    def __init__(self, SentineAddress, shards, socket_timeout=2000):
        """Connect to Sentinel and build the sharding ring.

        :param SentineAddress: Sentinel addresses, passed straight to
            ``redis.sentinel.Sentinel`` (e.g. ``[("host", 26379), ...]``).
        :param shards: shard names.  Each name is used both as the Sentinel
            master name and as the name hashed into the ring, so the list must
            match the shards configured on the Jedis side.
        :param socket_timeout: forwarded to redis-py for Sentinel and replica
            connections.  NOTE(review): redis-py interprets this in seconds
            (Jedis timeouts are milliseconds), so the default of 2000 means
            roughly 33 minutes; callers probably want something like 2.
        """
        self.socket_timeout = socket_timeout
        self.SentineAddress = SentineAddress
        self.SharedNames = shards

        self.MySentinel = Sentinel(self.SentineAddress, socket_timeout=self.socket_timeout)
        self._filter()                   # sets self.shards (reachable shards only)
        self.resources = self._slaves()  # shard name -> redis client
        self.nodes = self._initialize()  # ring: node hash -> shard name

    def _slaves(self):
        """Map every shard in ``self.shards`` to a replica client from Sentinel.

        Shards for which ``slave_for`` raises are reported and left out of
        the returned dict.
        """
        resources = {}
        for name in self.shards:
            try:
                resources[name] = self.MySentinel.slave_for(name, socket_timeout=self.socket_timeout)
            except Exception as e:
                # Best effort: report and skip; get() raises for keys on this shard.
                print("=========", name, str(e))
        return resources

    def _filter(self):
        """Set ``self.shards`` to the configured shards whose master Sentinel can discover.

        NOTE: the ring is built from ``self.shards``.  A dropped shard's keys
        are therefore routed to a neighbouring shard instead of failing, so
        lookups only match Jedis when every configured shard is reachable.
        """
        self.shards = []
        for name in self.SharedNames:
            try:
                self.MySentinel.discover_master(name)
            except Exception as e:
                # Report instead of swallowing silently: skipping a shard
                # changes key routing (see docstring).
                print("=========", name, str(e))
                continue
            self.shards.append(name)

    def _initialize(self):
        """Build the ring: a TreeMap from virtual-node hash to shard name.

        Each shard gets ``160 * weight`` virtual nodes, named
        ``name + "*" + weight + n`` and hashed with ``ByteBuffer.mmh``.
        """
        nodes = TreeMap()
        weight = self.DEFAULT_WEIGHT
        for shard in self.shards:
            for n in range(self.VIRTUAL_NODES_PER_WEIGHT * weight):
                nodes.put(ByteBuffer.mmh('{}*{}{}'.format(shard, weight, n)), shard)
        return nodes

    def shard_for(self, key):
        """Return the name of the shard that owns *key* on the ring.

        :param key: the Redis key; it is hashed with ``ByteBuffer.mmh``.
        """
        hashkey = ByteBuffer.mmh(key)
        try:
            # First virtual node at or after the key's hash.
            node = self.nodes.tail_map(hashkey).first_key()
        except KeyError:
            # The tail is empty when the hash is above every node; the ring
            # wraps around to its lowest node.  NOTE(review): relies on
            # pytreemap's first_key() raising KeyError on an empty map.
            node = self.nodes.first_key()
        return self.nodes.get(node)

    def get(self, key):
        """GET *key* from a replica of the shard that owns it.

        :raises KeyError: if the owning shard has no client (``slave_for``
            failed while the object was being built).
        """
        shard = self.shard_for(key)
        redisobj = self.resources.get(shard)
        if redisobj is None:
            raise KeyError("no redis client for shard %r (key %r)" % (shard, key))
        return redisobj.get(key)
 

# @File  : ByteBuffer.py
# @Author: Vam
# @Date  : 2020-09-08
# @Desc  : Port of the parts of Java's ByteBuffer needed by the MurmurHash
#          (mmh) algorithm.
class ByteBuffer(object):
    """Minimal Python port of ``java.nio.ByteBuffer`` plus Java long arithmetic.

    Only what :meth:`mmh` needs is implemented.  Unlike Java, :meth:`put`
    appends to the backing ``bytearray`` without moving either position, and
    reads past the end of ``buff`` return 0.  That zero padding is what the
    MurmurHash tail block relies on.
    """

    _LONG_MASK = 0xFFFFFFFFFFFFFFFF

    def __init__(self, buff:bytearray = None, position:int = 0, mark:int = -1, capacity:int = 0, limit:int = 0, order = "BIG_ENDIAN"):
        """
        :param buff: backing byte array; a new empty one when omitted.
        :param position: index of the next byte to read.
        :param mark: a previously read position to return to; -1 means unset.
        :param capacity: capacity; defaults to ``len(buff)``.
        :param limit: end of the readable data; 0 means "use capacity".
        :param order: "BIG_ENDIAN" or "LITTLE_ENDIAN" for multi-byte reads.
        :raises Exception: if mark > position, or limit/position is out of range.
        """
        self.buff = buff or bytearray()
        self.position = position
        # Use self.buff: ``buff`` may be None (len(None) raised TypeError).
        self.capacity = capacity or len(self.buff)
        self.mark = -1
        self._limit(limit)
        self._position(position)
        self._order = order
        if mark >= 0:
            if mark > position:
                raise Exception("IllegalArgumentException:mark:%s, pos:%s" % (mark, position))
            self.mark = mark

    @classmethod
    def long_overflow(cls, val):
        """Wrap ``val`` into the signed 64-bit range, like Java long overflow."""
        maxint = 0x7fffffffffffffff
        if not -maxint - 1 <= val <= maxint:
            val = (val + (maxint + 1)) % (2 * (maxint + 1)) - maxint - 1
        return val

    @classmethod
    def unsigned_right_shitf(cls, n, i):
        """Java's ``n >>> i`` on a 64-bit long (misspelled name kept for callers).

        A negative ``n`` is first reinterpreted as an unsigned 64-bit value.
        A negative ``i`` shifts left by ``abs(i)`` and negates the wrapped
        result (kept from the original; this is not Java semantics).
        """
        if n < 0:
            # Same as ctypes.c_uint64(n).value: keep the low 64 bits.
            n &= cls._LONG_MASK
        if i < 0:
            return -cls.long_overflow(n << abs(i))
        return cls.long_overflow(n >> i)

    # Correctly spelled alias.
    unsigned_right_shift = unsigned_right_shitf

    def _position(self, newPosition):
        """Set the read position; drop the mark if it lies beyond it."""
        if newPosition > self.limit or newPosition < 0:
            raise Exception("IllegalArgumentException")
        self.position = newPosition
        if self.mark > self.position:
            self.mark = -1
        return self

    def _limit(self, newLimit):
        """Set the limit (0 means capacity), clamping position and mark to it."""
        if newLimit > self.capacity or newLimit < 0:
            raise Exception("IllegalArgumentException")
        self.limit = newLimit or self.capacity
        if self.position > self.limit:
            self.position = self.limit
        if self.mark > self.limit:
            self.mark = -1
        return self

    @classmethod
    def get_java_long(cls, num, unsignal=False):
        """Wrap ``num`` to a 64-bit Java long (unsigned when ``unsignal``)."""
        return cls.get_java_int(num, base=64, unsignal=unsignal)

    @classmethod
    def get_java_int(cls, num, base=32, unsignal=False):
        """Wrap ``num`` into a ``base``-bit integer, the way Java overflow does.

        :param num: an int, or a single-byte ``bytes`` object (its ordinal is used).
        :param base: bit width (32 for int, 64 for long).
        :param unsignal: return the unsigned value in ``[0, 2**base)`` instead
            of the two's-complement signed value.
        """
        if isinstance(num, bytes):
            num = ord(num)
        modulus = 1 << base
        if unsignal:
            return num % modulus
        half = modulus >> 1
        # Two's-complement wrap into [-2**(base-1), 2**(base-1)).  The previous
        # bit-string version raised ValueError for exactly +/-(2**(base-1) - 1)
        # and returned +2**(base-1) where Java yields -2**(base-1).
        return (num + half) % modulus - half

    @classmethod
    def allocate(cls, capacity=8):
        """Return an empty buffer with the given capacity."""
        return ByteBuffer(capacity=capacity)

    @classmethod
    def wrap(cls, buff:bytearray, offset=0, length=None):
        """Return a buffer over ``length`` bytes of ``buff`` starting at ``offset``.

        As in Java, ``length`` is a byte count, not an end index.  The bytes
        are copied when a sub-range is requested.
        """
        if not offset and length is None:
            return cls(buff=buff)
        end = None if length is None else offset + length
        return cls(buff=buff[offset:end])

    def remaining(self):
        """Number of bytes between position and limit."""
        return self.limit - self.position

    def get(self, index:int):
        """Absolute read of the byte at ``index``."""
        return self.buff[index]

    def put(self, b):
        """Append ``b``'s remaining bytes to this buffer's backing array.

        Unlike Java, neither buffer's position is advanced.
        """
        n = b.remaining()
        if n > self.remaining():
            raise Exception("BufferOverflowException")
        for i in range(b.position, b.position + n):
            self.buff.append(b.buff[i])
        return self

    def order(self, ByteOrder:str):
        """Set the byte order ("BIG_ENDIAN" / "LITTLE_ENDIAN"); returns self."""
        self._order = ByteOrder
        return self

    def rewind(self):
        """Reset the limit to capacity and the position to 0; returns self."""
        self._limit(self.capacity or len(self.buff))
        self._position(0)
        return self

    def nextGetIndex(self, nb):
        """Advance the position by ``nb`` and return the old position."""
        if self.limit - self.position < nb:
            raise Exception("BufferUnderflowException")
        p = self.position
        self.position += nb
        return p

    def getLong(self):
        """Relative read of 8 bytes in the current byte order.

        The result is not wrapped to a signed long; callers wrap it with
        get_java_long.
        """
        if self._order == "BIG_ENDIAN":
            return self.getLongB(self.nextGetIndex(8))
        else:
            return self.getLongL(self.nextGetIndex(8))

    def _get(self, index):
        """Byte at ``index``, or 0 past the end of the data (zero padding)."""
        if index >= len(self.buff):
            return 0
        else:
            return self.buff[index]

    def getLongB(self, a):
        """Combine bytes a..a+7 as a big-endian 64-bit value."""
        return self.makeLong(self._get(a),
                             self._get(a + 1),
                             self._get(a + 2),
                             self._get(a + 3),
                             self._get(a + 4),
                             self._get(a + 5),
                             self._get(a + 6),
                             self._get(a + 7))

    def getLongL(self, a):
        """Combine bytes a..a+7 as a little-endian 64-bit value."""
        return self.makeLong(self._get(a + 7),
                             self._get(a + 6),
                             self._get(a + 5),
                             self._get(a + 4),
                             self._get(a + 3),
                             self._get(a + 2),
                             self._get(a + 1),
                             self._get(a))

    def makeLong(self, b7=0, b6=0, b5=0, b4=0, b3=0, b2=0, b1=0, b0=0):
        """Assemble eight bytes, most significant first, into one integer."""
        return ((ByteBuffer.get_java_long(b7) << 56) |
                ((ByteBuffer.get_java_long(b6) & 0xff) << 48) |
                ((ByteBuffer.get_java_long(b5) & 0xff) << 40) |
                ((ByteBuffer.get_java_long(b4) & 0xff) << 32) |
                ((ByteBuffer.get_java_long(b3) & 0xff) << 24) |
                ((ByteBuffer.get_java_long(b2) & 0xff) << 16) |
                ((ByteBuffer.get_java_long(b1) & 0xff) << 8) |
                ((ByteBuffer.get_java_long(b0) & 0xff)))

    @classmethod
    def mmh(cls, key):
        """MurmurHash64A of ``key`` with seed 0x1234ABCD, as a signed 64-bit int.

        The input is read in 8-byte little-endian blocks, followed by a
        zero-padded tail block and a final avalanche.

        :param key: ``str`` (encoded as UTF-8) or ``bytes``/``bytearray``.
        :return: int in ``[-2**63, 2**63)``.
        """
        if isinstance(key, (bytes, bytearray)):
            data = bytearray(key)
        else:
            data = bytearray(key, encoding="utf-8")
        buf = cls.wrap(data)
        buf.rewind()
        buf.order("LITTLE_ENDIAN")
        seed = cls.get_java_int(0x1234ABCD)
        m = cls.get_java_long(0xc6a4a7935bd1e995)
        r = 47

        h = cls.get_java_long(seed ^ (buf.remaining() * m))

        # Body: mix in each full 8-byte block.
        while buf.remaining() >= 8:
            k = cls.get_java_long(buf.getLong())
            k = cls.get_java_long(k * m)
            k = cls.get_java_long(k ^ cls.unsigned_right_shitf(k, r))
            k = cls.get_java_long(k * m)
            h = cls.get_java_long(h ^ k)
            h = cls.get_java_long(h * m)

        # Tail: the remaining 1-7 bytes, zero-padded into one little-endian long.
        if buf.remaining() > 0:
            finish = cls.allocate(8).order("LITTLE_ENDIAN")
            finish.put(buf).rewind()
            h = cls.get_java_long(h ^ finish.getLong())
            h = cls.get_java_long(h * m)

        # Final avalanche.
        h = cls.get_java_long(h ^ cls.unsigned_right_shitf(h, r))
        h = cls.get_java_long(h * m)
        h = cls.get_java_long(h ^ cls.unsigned_right_shitf(h, r))
        return h

if __name__ == '__main__':
    # Replace with your Sentinel endpoints (26379 is Sentinel's default port);
    # the original placeholders port1/port2/port3 were undefined names.
    SentineAddress = [('ip1', 26379), ('ip2', 26379), ('ip3', 26379)]
    # Must match the shards configured on the Jedis side: here shard_1 ... shard_200
    # (range(1, 200) only produced 199 names).
    shards = [f'shard_{i}' for i in range(1, 201)]
    obj = RedisShared(SentineAddress, shards)
    k = "good-file-feature:123455"  # assume this key lives on shard_100
    for _ in range(10):
        # perf_counter is a monotonic, high-resolution clock meant for timing.
        st = time.perf_counter()
        ans = obj.get(k)
        print(ans, '====ind========', time.perf_counter() - st)  # seconds; about 0.4 ms observed

posted @ 2022-07-20 10:02  仙守  阅读(85)  评论(0)  编辑  收藏  举报