Python 版本的 MySQL Text Resultset Row 协议代码实现
import struct,sys
from socket import *
from contextlib import closing
import hashlib,os
from functools import partial
from prettytable import PrettyTable #[liuzhuan] 引入表格
# SHA-1 constructor used by the mysql_native_password scramble.
sha1_new = partial(hashlib.new, 'sha1')
SHA1_HASH_SIZE = 20

# Client/server capability flags exchanged during the handshake.
MULTI_RESULTS = 1 << 17
SECURE_CONNECTION = 1 << 15
CLIENT_PLUGIN_AUTH = 1 << 19
CLIENT_CONNECT_ATTRS = 1 << 20
CLIENT_PROTOCOL_41 = 1 << 9
CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21
CLIENT_DEPRECATE_EOF = 1 << 24
LONG_PASSWORD = 1
LONG_FLAG = 1 << 2
PROTOCOL_41 = 1 << 9
TRANSACTIONS = 1 << 13
# Capabilities this client always advertises in its handshake response.
CAPABILITIES = (
    LONG_PASSWORD | LONG_FLAG | PROTOCOL_41 | TRANSACTIONS
    | SECURE_CONNECTION | MULTI_RESULTS
    | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_ATTRS | CLIENT_DEPRECATE_EOF)
# CLIENT_CONNECT_WITH_DB is bit 3 (0x08). The old value 9 also carried the
# LONG_PASSWORD bit, so "flags & CLIENT_CONNECT_WITH_DB" tested the wrong bit.
CLIENT_CONNECT_WITH_DB = 1 << 3
# Maximum packet size advertised in the handshake response (3-byte length limit).
max_packet_size = 2 ** 24 - 1
# Collation id sent in the handshake response (45 = utf8mb4_general_ci).
charset_id = 45
class PreparPacket(object):
    """Build the payloads of client-to-server MySQL protocol packets.

    Methods return payload bytes only; ``Prepar_head`` builds the 4-byte
    header (3-byte little-endian length + 1-byte sequence id) that must
    precede each payload on the wire.
    """

    # Parameter column types whose values are sent as length-encoded strings
    # (VAR_STRING, STRING, VARCHAR).
    _STRING_TYPES = (0xfd, 0xfe, 0x0f)
    # Integer column types -> (unsigned struct format, signed struct format).
    _INT_FORMATS = {
        0x01: ('<B', '<b'),  # TINY
        0x02: ('<H', '<h'),  # SHORT
        0x03: ('<I', '<i'),  # LONG
        0x09: ('<I', '<i'),  # INT24 (sent as 4 bytes)
        0x08: ('<Q', '<q'),  # LONGLONG
    }
    # Second byte of a 2-byte parameter type: marks the value as unsigned.
    _UNSIGNED_FLAG = 0x80

    def __init__(self):
        pass

    def __null_bitmap(self, num_params):
        """Return an all-zero NULL bitmap with one bit per parameter.

        Sized ceil(num_params / 8) bytes for any count (the previous version
        returned None for more than 32 parameters).
        """
        return bytearray((num_params + 7) // 8)

    def is_null(self, null_bytes, pos):
        """Return a truthy value if bit *pos* of the NULL bitmap *null_bytes* is set."""
        bit = null_bytes[pos // 8]
        if isinstance(bit, str):  # indexing a Python 2 str yields a 1-char str
            bit = ord(bit)
        return bit & (1 << (pos % 8))

    @staticmethod
    def _lenenc_int(value):
        """Encode a non-negative int as a MySQL length-encoded integer."""
        if value < 251:
            return struct.pack('<B', value)
        if value < 1 << 16:
            return b'\xfc' + struct.pack('<H', value)
        if value < 1 << 24:
            return b'\xfd' + struct.pack('<I', value)[:3]
        return b'\xfe' + struct.pack('<Q', value)

    def COM_Query(self, sql):
        """Return a COM_QUERY (0x03) payload carrying *sql* encoded as UTF-8."""
        return struct.pack('B', 3) + sql.encode('utf8')

    def COM_STMT_PREPARE(self, sql):
        """Return a COM_STMT_PREPARE (0x16) payload carrying *sql* encoded as UTF-8."""
        return struct.pack('B', 0x16) + sql.encode('utf8')

    def COM_STMT_EXECUTE(self, statement_id, flags, num_params, values, column_info):
        """Return a COM_STMT_EXECUTE (0x17) payload.

        :param values: dict of parameter name -> value; its iteration order is
            the binding order. ``None`` is sent as SQL NULL.
        :param column_info: column-definition dicts (``name`` bytes, ``type``
            int). A parameter whose name matches a column takes that column's
            type; others are sent as strings (0x0f).

        Integers bound to integer columns use the binary integer encoding
        (unsigned flag set for values >= 0). Every other value is sent as a
        length-encoded UTF-8 string, which the server converts.
        """
        # iteration_count is always 1.
        payload = struct.pack('<BIBI', 0x17, statement_id, flags, 0x01)
        if num_params <= 0:
            return payload
        null_map = self.__null_bitmap(num_params)
        for i, name in enumerate(values):
            # Iterating a dict yields keys; the NULL test must look at the value.
            if values[name] is None:
                null_map[i // 8] |= 1 << (i % 8)
        types = b''
        data = b''
        for name, value in values.items():
            col_type = 0x0f  # default: string
            for col in column_info:
                if col['name'].decode() == name:
                    col_type = col['type']
            int_formats = self._INT_FORMATS.get(col_type)
            if value is None:
                # NULL is signalled by the bitmap only; no value bytes follow.
                types += struct.pack('<BB', col_type, 0)
            elif int_formats is not None and isinstance(value, int):
                unsigned = value >= 0
                types += struct.pack('<BB', col_type, self._UNSIGNED_FLAG if unsigned else 0)
                data += struct.pack(int_formats[0] if unsigned else int_formats[1], value)
            else:
                if isinstance(value, (bytes, bytearray)):
                    raw = bytes(value)
                else:
                    raw = str(value).encode('utf8')
                wire_type = col_type if col_type in self._STRING_TYPES else 0x0f
                types += struct.pack('<BB', wire_type, 0)
                # The length prefix counts encoded bytes, not characters.
                data += self._lenenc_int(len(raw)) + raw
        # NULL bitmap, new-params-bound flag (1), parameter types, then values.
        return payload + bytes(null_map) + struct.pack('B', 1) + types + data

    def Prepar_head(self, playload_length, seq_id):
        """Return the 4-byte header for a payload of *playload_length* bytes.

        The sequence id is a single byte, so it wraps modulo 256.
        NOTE(review): payloads >= 2**24 - 1 bytes would need splitting; not handled.
        """
        return struct.pack('<I', playload_length)[:3] + struct.pack('!B', seq_id & 0xff)

    def handshakeresponsepacket(self, server_packet_info, user, password, database=None):
        """Build a HandshakeResponse41 payload answering the server greeting.

        :param server_packet_info: dict from ``unpack_handshake``; its
            ``capability_flags`` and ``auth_plugin_data`` (salt) are used.
        :param database: optional initial schema. It is sent only when it is
            non-empty and the server advertises CLIENT_CONNECT_WITH_DB.
        """
        server_flags = server_packet_info['capability_flags']
        send_database = bool(database) and bool(server_flags & CLIENT_CONNECT_WITH_DB)
        client_flag = CAPABILITIES  # already includes MULTI_RESULTS
        if send_database:
            client_flag |= CLIENT_CONNECT_WITH_DB
        response_packet = struct.pack('<IIB23s', client_flag, max_packet_size, charset_id, b'')
        response_packet += user.encode() + b'\0'
        auth_response = self.sha1_password(password=password,
                                           auth_plugin_data=server_packet_info['auth_plugin_data'])
        if server_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA:
            response_packet += self._lenenc_int(len(auth_response)) + auth_response
        elif server_flags & SECURE_CONNECTION:
            response_packet += struct.pack('B', len(auth_response)) + auth_response
        else:
            response_packet += auth_response + b'\0'
        # The schema field exists only when the client sets CLIENT_CONNECT_WITH_DB;
        # the old code appended a stray NUL here even without a database.
        if send_database:
            response_packet += database.encode() + b'\0'
        if server_flags & CLIENT_PLUGIN_AUTH:
            # Empty auth plugin name; Send() handles an AuthSwitchRequest reply.
            response_packet += b'\0'
        if server_flags & CLIENT_CONNECT_ATTRS:
            _connect_attrs = {
                '_client_name': 'pymysql',
                '_pid': str(os.getpid()),
                '_client_version': '3.6.5',
                'program_name': sys.argv[0]
            }
            connect_attrs = b''
            for key, value in _connect_attrs.items():
                for item in (key, value):
                    raw = item.encode('utf8')
                    # Length-encoded: argv[0] may be longer than 250 bytes.
                    connect_attrs += self._lenenc_int(len(raw)) + raw
            response_packet += self._lenenc_int(len(connect_attrs)) + connect_attrs
        return response_packet

    def authswitchrequest(self, packet, offset, capability_flags, password):
        """Build the auth response to an AuthSwitchRequest.

        *offset* points just past the 0xfe status byte. A NUL-terminated
        plugin name follows it, then the new salt.

        Returns the mysql_native_password scramble. If no plugin is named or
        plugin auth is not negotiated, returns an empty response (the old code
        raised UnboundLocalError here).
        NOTE(review): the named plugin is not checked; only
        mysql_native_password scrambling is implemented.
        """
        end_pos = packet.find(b'\0', offset)
        if end_pos < 0:
            end_pos = len(packet)
        auth_name = packet[offset:end_pos].decode()
        auth_plugin_data = packet[end_pos + 1:]
        if capability_flags & CLIENT_PLUGIN_AUTH and auth_name:
            return self.sha1_password(password, auth_plugin_data)
        return bytearray()

    def sha1_password(self, password, auth_plugin_data):
        """Return the mysql_native_password scramble.

        The scramble is SHA1(pw) XOR SHA1(salt[:20] + SHA1(SHA1(pw))). An empty
        password yields an empty response, as the protocol expects.
        """
        if not password:
            return bytearray()
        stage1 = sha1_new(password.encode()).digest()
        stage2 = sha1_new(stage1).digest()
        s = sha1_new()
        s.update(auth_plugin_data[:SHA1_HASH_SIZE])
        s.update(stage2)
        return bytearray(a ^ b for a, b in zip(s.digest(), stage1))
class UnpackPacket(PreparPacket):
    """Decode server-to-client MySQL protocol packets."""

    # Binary-protocol types whose values are length-encoded strings decoded as
    # UTF-8: DECIMAL, NEWDECIMAL, VARCHAR, VAR_STRING, STRING, JSON, ENUM, SET
    # and the BLOB/TEXT family.
    _BINARY_TEXT_TYPES = frozenset((0x00, 0xf6, 0x0f, 0xfd, 0xfe, 0xf5, 0xf7, 0xf8,
                                    0xf9, 0xfa, 0xfb, 0xfc))
    # Length-encoded values returned as raw bytes: BIT, GEOMETRY.
    _BINARY_RAW_TYPES = frozenset((0x10, 0xff))
    # Fixed-width binary values -> struct format (integers decoded unsigned).
    _BINARY_FIXED_FORMATS = {
        0x01: '<B',  # TINY
        0x02: '<H',  # SHORT
        0x0d: '<H',  # YEAR
        0x03: '<I',  # LONG
        0x09: '<I',  # INT24
        0x08: '<Q',  # LONGLONG
        0x04: '<f',  # FLOAT
        0x05: '<d',  # DOUBLE
    }
    _TYPE_TIMESTAMP = 0x07
    _TYPE_DATE = 0x0a
    _TYPE_TIME = 0x0b
    _TYPE_DATETIME = 0x0c

    def __init__(self):
        super(UnpackPacket, self).__init__()

    @staticmethod
    def _read_lenenc_int(packet, offset):
        """Decode a length-encoded integer at *offset*.

        Returns ``(value, next_offset)``; *value* is None for the 0xfb NULL marker.

        :raises ValueError: on the invalid 0xff prefix.
        """
        first = packet[offset]
        if first < 0xfb:
            return first, offset + 1
        if first == 0xfb:
            return None, offset + 1
        if first == 0xfc:
            return struct.unpack('<H', packet[offset + 1:offset + 3])[0], offset + 3
        if first == 0xfd:
            return struct.unpack('<I', bytes(packet[offset + 1:offset + 4]) + b'\0')[0], offset + 4
        if first == 0xfe:
            return struct.unpack('<Q', packet[offset + 1:offset + 9])[0], offset + 9
        raise ValueError('invalid length-encoded integer prefix 0x%02x' % first)

    def unpack_handshake(self, packet, offset):
        """Parse the server's initial handshake starting at *offset* (after the 4-byte header).

        Returns a dict with these keys:

        - ``packet_header``: the protocol version byte.
        - ``server_version``: bytes.
        - ``thread_id``: a 1-tuple, kept as before.
        - ``auth_plugin_data``: salt part 1 + part 2.
        - ``capability_flags``: all 32 bits.
        - ``character_set_id`` and ``status_flags``.
        - ``auth_plugin_name``: only when present.
        """
        PLUGIN_AUTH = 1 << 19
        info = {}
        info['packet_header'] = packet[offset]
        offset += 1
        end = packet.find(b'\0', offset)
        info['server_version'] = packet[offset:end]
        offset = end + 1
        info['thread_id'] = struct.unpack('<I', packet[offset:offset + 4])
        offset += 4
        info['auth_plugin_data'] = packet[offset:offset + 8]
        offset += 8 + 1  # salt part 1 plus one filler byte
        info['capability_flags'] = struct.unpack('<H', packet[offset:offset + 2])[0]
        offset += 2
        (info['character_set_id'], info['status_flags'],
         capability_flags_2, auth_plugin_data_len) = struct.unpack('<BHHB', packet[offset:offset + 6])
        info['capability_flags'] |= capability_flags_2 << 16
        offset += 6 + 10  # the fields above plus 10 reserved bytes
        # Salt part 2 is max(13, auth_plugin_data_len - 8) bytes. *offset* is
        # an absolute index into *packet*, so no header adjustment is needed.
        salt2_len = max(13, auth_plugin_data_len - 8)
        if len(packet) >= offset + salt2_len:
            info['auth_plugin_data'] += packet[offset:offset + salt2_len]
            offset += salt2_len
        if info['capability_flags'] & PLUGIN_AUTH and len(packet) > offset:
            end = packet.find(b'\0', offset)
            if end < 0:  # name not NUL-terminated: take the rest of the packet
                end = len(packet)
            info['auth_plugin_name'] = packet[offset:end]
        return info

    def unpack_text_values(self, packet, column_info, payload_length):
        """Decode one text-protocol row into ``{column_name: bytes or None}``.

        Each value is a length-encoded string (lengths >= 251 use the
        0xfc/0xfd/0xfe prefixes); a 0xfb marker means SQL NULL.
        """
        values = {}
        offset = 0
        index = 0
        while offset < payload_length:
            length, offset = self._read_lenenc_int(packet, offset)
            name = column_info[index]['name']
            if length is None:
                values[name] = None
            else:
                values[name] = packet[offset:offset + length]
                offset += length
            index += 1
        return values

    @staticmethod
    def _read_binary_datetime(packet, offset, with_time):
        """Decode a binary DATE/DATETIME/TIMESTAMP at *offset* into text.

        The encoding is a 1-byte length (0, 4, 7 or 11), then year(2), month,
        day, hour, minute, second and microsecond(4), as far as the length
        allows. Returns ``(text, next_offset)``.
        """
        length = packet[offset]
        raw = packet[offset + 1:offset + 1 + length]
        year = month = day = hour = minute = second = 0
        if length >= 4:
            year, month, day = struct.unpack('<HBB', raw[:4])
        if length >= 7:
            hour, minute, second = struct.unpack('<BBB', raw[4:7])
        text = '%04d-%02d-%02d' % (year, month, day)
        if with_time:
            text += ' %02d:%02d:%02d' % (hour, minute, second)
            if length >= 11:
                text += '.%06d' % struct.unpack('<I', raw[7:11])[0]
        return text, offset + 1 + length

    @staticmethod
    def _read_binary_time(packet, offset):
        """Decode a binary TIME at *offset* into text such as ``-26:00:00``.

        The encoding is a 1-byte length (0, 8 or 12), then is_negative,
        days(4), hour, minute, second and microsecond(4). Days are folded into
        the hours. Returns ``(text, next_offset)``.
        """
        length = packet[offset]
        raw = packet[offset + 1:offset + 1 + length]
        negative = days = hour = minute = second = 0
        if length >= 8:
            negative, days, hour, minute, second = struct.unpack('<BIBBB', raw[:8])
        text = '%s%02d:%02d:%02d' % ('-' if negative else '', days * 24 + hour, minute, second)
        if length >= 12:
            text += '.%06d' % struct.unpack('<I', raw[8:12])[0]
        return text, offset + 1 + length

    def unpack_binary_protocol(self, packet, cols_type):
        """Decode one binary-protocol result row.

        *packet* is the row payload without its leading 0x00 byte, and
        *cols_type* lists the column type codes in order. Returns a list with
        None for SQL NULL, as follows:

        - strings, decimals and JSON come back as str;
        - BIT and GEOMETRY come back as bytes;
        - numbers come back as int or float;
        - temporal values come back as text.

        NOTE(review): integers are always decoded unsigned because column flags
        are not passed in, so negative values come out wrapped.

        :raises ValueError: for a column type this decoder does not know
            (previously such columns silently desynchronized the row).
        """
        # Result-row NULL bitmaps reserve their first 2 bits: column i is bit
        # i + 2, and the bitmap is (n + 7 + 2) // 8 bytes long.
        null_len = (len(cols_type) + 9) // 8
        null_bytes = packet[:null_len]
        offset = null_len
        values = []
        for i, col_type in enumerate(cols_type):
            if self.is_null(null_bytes, i + 2):
                values.append(None)
                continue
            if col_type in self._BINARY_TEXT_TYPES or col_type in self._BINARY_RAW_TYPES:
                length, offset = self._read_lenenc_int(packet, offset)
                raw = packet[offset:offset + length]
                offset += length
                if col_type in self._BINARY_TEXT_TYPES:
                    values.append(raw.decode('utf8', 'ignore'))
                else:
                    values.append(bytes(raw))
            elif col_type in self._BINARY_FIXED_FORMATS:
                fmt = self._BINARY_FIXED_FORMATS[col_type]
                size = struct.calcsize(fmt)
                values.append(struct.unpack(fmt, packet[offset:offset + size])[0])
                offset += size
            elif col_type in (self._TYPE_DATE, self._TYPE_DATETIME, self._TYPE_TIMESTAMP):
                text, offset = self._read_binary_datetime(packet, offset, col_type != self._TYPE_DATE)
                values.append(text)
            elif col_type == self._TYPE_TIME:
                text, offset = self._read_binary_time(packet, offset)
                values.append(text)
            else:
                raise ValueError('unsupported binary column type 0x%02x' % col_type)
        return values

    def unpack_text_column(self, packet):
        """Decode a ColumnDefinition41 payload into a dict.

        The keys catalog, schema, table, org_table, name and org_name hold
        bytes. The keys character_set, column_length, type and flag hold ints.
        """
        column = {}
        offset = 0
        for key in ('catalog', 'schema', 'table', 'org_table', 'name', 'org_name'):
            length, offset = self._read_lenenc_int(packet, offset)
            column[key] = packet[offset:offset + length]
            offset += length
        offset += 1  # length byte of the fixed-size fields that follow
        (column['character_set'], column['column_length'],
         column['type'], column['flag']) = struct.unpack('<HIBH', packet[offset:offset + 9])
        return column

    def com_prepare_ok(self, packet):
        """Decode a COM_STMT_PREPARE_OK payload, starting at its 0x00 status byte.

        Returns ``(statement_id, num_columns, num_params)``.
        """
        return struct.unpack('<IHH', packet[1:9])
class TcpClient(UnpackPacket):
    """Minimal blocking MySQL client that authenticates and runs one statement.

    ``Send()`` performs the handshake and then runs ``sql``. When
    ``type == 'pre'`` it runs as a prepared statement bound with ``values``;
    otherwise it is sent as COM_QUERY and the text result set is printed.
    """

    def __init__(self, host_content, user_name, password, databases, sql=None, values=None, type=None):
        """Connect to *host_content* (``'host:port'``) and store the query settings."""
        super(TcpClient, self).__init__()
        _host_content = host_content.split(':')
        self.sql = sql
        self.pre_values = values
        self.type = type
        self.user = user_name
        self.password = password
        self.database = databases
        HOST = _host_content[0]
        PORT = int(_host_content[1])
        self.BUFSIZ = 1024
        self.ADDR = (HOST, PORT)
        self.client = socket(AF_INET, SOCK_STREAM)
        try:
            self.client.connect(self.ADDR)
        except OSError:
            # closing() never gets an instance whose __init__ failed, so
            # release the socket here instead of leaking it.
            self.client.close()
            raise
        # __recv_data uses this short timeout to detect that the server stopped sending.
        self.client.settimeout(0.1)
        self.server_packet_info = {}
        self.packet = None
        self.column_info = []
        self.values = []

    def header(self, offset=None):
        """Parse the 4-byte packet header at *offset* (default 0).

        Sets ``payload_length`` (3-byte little-endian) and ``seq_id``, and
        leaves ``self.offset`` at the start of the payload.
        """
        self.offset = offset if offset else 0
        self.payload_length = (self.packet[self.offset + 2] << 16
                               | self.packet[self.offset + 1] << 8
                               | self.packet[self.offset])
        self.seq_id = self.packet[self.offset + 3]
        self.offset += 4

    def check_packet(self):
        """Check the reply to the auth-switch response.

        On OK (0x00), run the command. On EOF/ERR (0xfe/0xff), print the
        rest of the packet.
        """
        packet_header = self.packet[self.offset]
        self.offset += 1
        if packet_header == 0x00:
            print('connection ok')
            self.__command_prepare()
        elif packet_header in (0xfe, 0xff):
            print(self.packet[self.offset:])

    def Send(self):
        """Authenticate with the server, then run the configured statement."""
        self.__recv_data()
        self.server_packet_info = self.unpack_handshake(packet=self.packet, offset=self.offset)
        self.response_packet = self.handshakeresponsepacket(server_packet_info=self.server_packet_info,
                                                            user=self.user, password=self.password,
                                                            database=self.database)
        response_payload = len(self.response_packet)
        self.client.sendall(self.Prepar_head(response_payload, self.seq_id + 1) + self.response_packet)
        self.__recv_data()
        packet_header = self.packet[self.offset]
        self.offset += 1
        if packet_header == 0xff:
            # ERR packet: 2-byte error code, then the message.
            error_code = struct.unpack('<H', self.packet[self.offset:self.offset + 2])
            self.offset += 2
            print(error_code, self.packet[self.offset:])
        elif packet_header == 0xfe:
            # AuthSwitchRequest; a packet shorter than 9 bytes is a bare EOF packet instead.
            if len(self.packet) < 9:
                print('this is eof packet')
            else:
                _data = self.authswitchrequest(packet=self.packet, offset=self.offset, password=self.password,
                                               capability_flags=self.server_packet_info['capability_flags'])
                # Reply with the sequence id following the request's (was hard-coded to 3).
                self.client.sendall(self.Prepar_head(len(_data), self.seq_id + 1) + _data)
                self.__recv_data()
                self.check_packet()
        elif packet_header == 0x00:
            if len(self.packet) > 7:  # an OK packet is always longer than this
                print('ok packet')
                self.__command_prepare()

    def __deprecate_eof(self):
        """True if the server supports CLIENT_DEPRECATE_EOF, which this client always requests."""
        return bool(self.server_packet_info.get('capability_flags', 0) & CLIENT_DEPRECATE_EOF)

    def __skip_packet(self):
        """Skip the whole packet (header and payload) at ``self.offset``."""
        self.header(self.offset)
        self.offset += self.payload_length

    def __read_column_count(self):
        """Read the first packet of a command response.

        The packet's header must already be parsed. Returns the result-set
        column count. Returns None for responses without a result set: ERR
        (printed), OK (e.g. INSERT/UPDATE) and LOCAL INFILE.
        """
        first = self.packet[self.offset]
        body = self.packet[self.offset + 1:self.offset + self.payload_length]
        if first == 0xff:
            print(struct.unpack('<H', body[:2]), body[2:])
            return None
        if first == 0x00:
            print('ok packet')
            return None
        if first == 0xfb:
            print('LOCAL INFILE request is not supported')
            return None
        if first == 0xfc:
            return struct.unpack('<H', body[:2])[0]
        if first > 0xfc:
            raise ValueError('unexpected column count prefix 0x%02x' % first)
        return first

    def __is_end_of_rows(self):
        """Return True if the just-parsed packet ends the text row stream.

        0xff is an ERR packet and is printed. 0xfe is the EOF/OK terminator,
        unless the payload is at least 0xffffff bytes: then it is a row whose
        first value uses the 8-byte length prefix. A row may legitimately start
        with 0x00 (an empty first value), so 0x00 is NOT a terminator.
        """
        first = self.packet[self.offset]
        if first == 0xff:
            print(self.packet[self.offset + 1:self.offset + self.payload_length])
            return True
        return first == 0xfe and self.payload_length < 0xffffff

    def __unpack_text_packet(self):
        """Parse the COM_QUERY response in ``self.packet`` and print it.

        The response is a column-count packet, then one ColumnDefinition
        packet per column. An EOF packet follows unless CLIENT_DEPRECATE_EOF
        was negotiated. Then come the text rows, ended by an OK/EOF (0xfe) or
        ERR (0xff) packet.
        """
        column_count = self.__read_column_count()
        if column_count is None:
            return
        self.offset += self.payload_length
        print('[liuzhuan] 偏移位置跳过column_count包后为: ' + str(self.offset))
        for i in range(column_count):
            self.header(offset=self.offset)
            self.column_info.append(
                self.unpack_text_column(self.packet[self.offset:self.offset + self.payload_length]))
            self.offset += self.payload_length
            print('[liuzhuan] 第{0}个ColumnDefinition包,偏移位置{1}'.format(i, self.offset))
        print('[liuzhuan] 偏移位置跳过全部ColumnDefinition包后为: ' + str(self.offset))
        if not self.__deprecate_eof():
            self.__skip_packet()  # EOF packet separating definitions from rows
        # Stop at the end of the buffer too, in case no terminator was received.
        while self.offset + 4 <= len(self.packet):
            self.header(offset=self.offset)
            if self.__is_end_of_rows():
                break
            self.values.append(self.unpack_text_values(
                self.packet[self.offset:self.offset + self.payload_length],
                self.column_info, self.payload_length))
            self.offset += self.payload_length
        if self.type:
            return
        print('[liuzhuan] 测试colume def包输出')
        print(self.column_info)
        print('[liuzhuan] 测试colume value包输出')
        titles = [str(name, encoding='utf-8') for name in self.values[0]] if self.values else []
        table = PrettyTable(titles)
        for row in self.values:
            # SQL NULL comes back as None, which str(..., encoding=...) cannot handle.
            table.add_row(['NULL' if cell is None else str(cell, encoding='utf-8', errors='replace')
                           for cell in row.values()])
        print(table)

    def __prepared_statements(self):
        """Prepare ``self.sql``, execute it with ``self.pre_values`` and print the binary rows."""
        _pre_packet = self.COM_STMT_PREPARE(sql=self.sql)
        self.client.sendall(self.Prepar_head(len(_pre_packet), self.next_seq_id) + _pre_packet)
        self.__recv_data(result=True)
        flags = {'NO_CURSOR': 0x00, 'READ_ONLY': 0x01, 'FOR_UPDATE': 0x02, 'SCROLLABLE': 0x04}
        if self.packet[self.offset] != 0x00:
            # COM_STMT_PREPARE failed (ERR packet): there is nothing to execute.
            print(self.packet[self.offset:])
            return
        self.statement_id, num_columns, self.num_params = self.com_prepare_ok(
            self.packet[self.offset:self.offset + self.payload_length])
        self.offset += self.payload_length
        deprecate_eof = self.__deprecate_eof()
        # Parameter definitions are not needed; skip them (and their EOF).
        for _ in range(self.num_params):
            self.__skip_packet()
        if self.num_params and not deprecate_eof:
            self.__skip_packet()
        # Column definitions supply the types used to decode binary rows.
        for _ in range(num_columns):
            self.header(self.offset)
            self.column_info.append(
                self.unpack_text_column(self.packet[self.offset:self.offset + self.payload_length]))
            self.offset += self.payload_length
        if num_columns and not deprecate_eof:
            self.__skip_packet()

        execute_pack = self.COM_STMT_EXECUTE(statement_id=self.statement_id, flags=flags['NO_CURSOR'],
                                             num_params=self.num_params, values=self.pre_values,
                                             column_info=self.column_info)
        self.client.sendall(self.Prepar_head(len(execute_pack), self.next_seq_id) + execute_pack)
        self.__recv_data(result=True)
        column_count = self.__read_column_count()
        if column_count is None:
            return
        self.offset += self.payload_length
        # The execute response repeats the column definitions; reuse those from the prepare step.
        for _ in range(column_count):
            self.__skip_packet()
        if not deprecate_eof:
            self.__skip_packet()
        col_types = [col['type'] for col in self.column_info]
        values = []
        while self.offset + 4 <= len(self.packet):
            self.header(self.offset)
            row_header = self.packet[self.offset]
            if row_header != 0x00:
                # 0xfe ends the rows; 0xff is an error. The old loop spun on anything else.
                if row_header == 0xff:
                    print(self.packet[self.offset + 1:self.offset + self.payload_length])
                break
            values.append(self.unpack_binary_protocol(
                self.packet[self.offset + 1:self.offset + self.payload_length], col_types))
            self.offset += self.payload_length
        print('|'.join([col['name'].decode() for col in self.column_info]))
        for row in values:
            print(tuple(row))

    def __command_prepare(self):
        """Send the configured statement (prepared or plain COM_QUERY) and print its result."""
        self.next_seq_id = 0  # each new command starts again at sequence id 0
        if self.type == 'pre':
            self.__prepared_statements()
            return
        _com_packet = self.COM_Query(self.sql)
        self.client.sendall(self.Prepar_head(len(_com_packet), self.next_seq_id) + _com_packet)
        # Result sets may span many recv() calls, so read until the socket goes idle.
        self.__recv_data(result=True)
        self.__unpack_text_packet()

    def __recv_data(self, result=None):
        """Read server data into ``self.packet`` and parse the first packet header.

        With *result* falsy, a single recv() is done (enough for the short
        handshake/auth packets). Otherwise reading continues until the socket
        has been idle for three consecutive timeouts. In both modes, reading
        stops when the server closes the connection (the old loop spun forever
        on that).

        :raises ConnectionError: if no data at all was received.
        """
        self.packet = b''
        idle = 0
        while True:
            try:
                chunk = self.client.recv(self.BUFSIZ)
            except timeout:
                idle += 1
                if idle >= 3:
                    break
                continue
            if not chunk:
                break
            self.packet += chunk
            if result is None:
                break
            idle = 0
        if not self.packet:
            raise ConnectionError('no data received from MySQL server')
        self.header()  # initialises payload_length/seq_id for the first packet

    def close(self):
        """Close the TCP connection (called by contextlib.closing)."""
        self.client.close()
if __name__ == '__main__':
    # Guarded so importing this module does not open a database connection.
    #
    # Prepared-statement example:
    # sql = 'select * from information_schema.tables where table_schema=?'
    # values = {'table_schema': 'information_schema'}
    # with closing(TcpClient('192.168.10.12:3306', 'root', 'root', '', sql, values, 'pre')) as tcpclient:
    #     tcpclient.Send()
    #
    # Plain-query example:
    # sql = 'select * from information_schema.tables'
    # values = {'table_schema': 'information_schema'}
    # with closing(TcpClient('192.168.10.12:3306', 'root', 'root', '', sql, values, '')) as tcpclient:
    #     tcpclient.Send()

    # [liuzhuan] local test: plain COM_QUERY against test.book.
    sql = 'select * from test.book'
    values = {'test': 'book'}
    with closing(TcpClient('127.0.0.1:3306', 'root', '123456', '', sql, values, '')) as tcpclient:
        tcpclient.Send()
以下为脚本运行输出（调试时用 pdb 跟踪过各包的偏移量）：
(venv) liuzhuan@liuzhuan-ThinkPad-X250:~/wintrust23/mysql-protocol-python/wwwbjq$ python command_prepare.py
connection ok
[liuzhuan] 偏移位置跳过column_count包后为: 5
[liuzhuan] 第0个ColumnDefinition包,偏移位置57
[liuzhuan] 第1个ColumnDefinition包,偏移位置109
[liuzhuan] 第2个ColumnDefinition包,偏移位置157
[liuzhuan] 偏移位置跳过全部ColumnDefinition包后为: 157
[liuzhuan] 测试colume def包输出
[{'catalog': b'def', 'schema': b'test', 'table': b'book', 'org_table': b'book', 'name': b'publish', 'org_name': b'publish', 'character_set': 45, 'column_length': 128, 'type': 253, 'flag': 0}, {'catalog': b'def', 'schema': b'test', 'table': b'book', 'org_table': b'book', 'name': b'address', 'org_name': b'address', 'character_set': 45, 'column_length': 128, 'type': 253, 'flag': 0}, {'catalog': b'def', 'schema': b'test', 'table': b'book', 'org_table': b'book', 'name': b'phone', 'org_name': b'phone', 'character_set': 45, 'column_length': 128, 'type': 253, 'flag': 0}]
[liuzhuan] 测试colume value包输出
+---------------------------------+----------------------------------+-------------+
| publish | address | phone |
+---------------------------------+----------------------------------+-------------+
| mysql优化之路 | 中国北京市区通州区三大街五号胡同 | 13323376667 |
| Adobe Photoshop CC 2017经典教程 | 天津市和平区 | 13278728288 |
| MariaDB原理与实现 | 西场火箭发射基地 | 13211190888 |
| 深入理解MySQL核心技术 | 乃海西沙是的 | 13211190666 |
+---------------------------------+----------------------------------+-------------+