socketserver模块使用与源码分析
前言
在前面的学习中我们其实已经可以通过socket
模块来建立我们的服务端,并且还介绍了关于TCP协议的粘包问题。但是还有一个非常大的问题就是我们所编写的Server端是不支持并发性服务的,在我们之前的代码中只能加入一个通信循环来进行排队式的单窗口一对一服务。那么这一篇文章将主要介绍如何使用socketserver
模块来建立具有并发性的Server端。
我们先看它的一段代码,对照代码来看功能。
#!/usr/bin/env python3 # _*_ coding:utf-8 _*_ # ==== 使用socketserver创建支持多并发性的服务器 TCP协议 ==== import socketserver class MyServer(socketserver.BaseRequestHandler): """自定义类""" def handle(self): """handle处理请求""" print("双向链接通道建立完成:", self.request) # 对于TCP协议来说,self.request相当于双向链接通道conn,即accept()的第一部分 print("客户端的信息是:", self.client_address) # 对于TCP协议来说,相当于accept()的第二部分,即客户端的ip+port while 1: # 开始内层通信循环 try: # # bug修复:针对windows环境 data = self.request.recv(1024) if not data: break # bug修复:针对类UNIX环境 print("收到客户机[{0}]的消息:[{1}]".format(self.client_address, data)) self.request.sendall(data.upper()) # #sendall是重复调用send. except Exception as e: break self.request.close() # 当出现异常情况下一定要关闭链接 if __name__ == '__main__': s1 = socketserver.ThreadingTCPServer(("0.0.0.0", 6666), MyServer) # 公网服务器绑定 0.0.0.0 私网测试为 127.0.0.1 s1.serve_forever() # 启动服务
1.导入
socketserver
模块2.创建一个新的类,并继承
socketserver.BaseRequestHandler
类3.覆写
handle
方法,对于TCP协议来说,self.request
相当于双向链接通道conn
,self.client_address
相当于被服务方的ip和port信息,也就是addr
,而整个handle
方法相当于链接循环。4.写入收发逻辑规则
5.防止客户端发送空的消息已致双方卡死
6.防止客户端突然断开已致服务端崩溃
7.粘包优化(可选)
8.实例化
socketserver.ThreadingTCPServer
类,并传入IP+port,以及刚写好的类名9.使用
socketserver.ThreadingTCPServer
实例化对象中的server_forever( )
方法启动服务它其实是这样的:
我们不用管链接循环,因为在执行
handle
方法之前内部已经帮我们做好了。当我们使用serve_forever()
方法的时候便开始监听链接描述符对象,一旦有链接请求就创建一个子线程来处理该链接。
基于UDP协议的socketserver服务端
基于UDP协议的socketserver
服务端与基于TCP协议的socketserver
服务端大相径庭,但是还是有几点不太一样的地方。
对TCP来说:
self.request = 双向链接通道(conn)
对UDP来说:
self.request = (client_data_byte,udp的套接字对象)
#!/usr/bin/env python3 # _*_ coding:utf-8 _*_ # ==== 使用socketserver创建支持多并发性的服务器 UDP协议 ==== import socketserver class MyServer(socketserver.BaseRequestHandler): """自定义类""" def handle(self): """handle处理请求""" # 由于UDP是基于消息的协议,故根本不用通信循环 data = self.request[0] # 对于UDP协议来说,self.request其实是个元组。第一个元素是消息内容主题(Bytes类型),相当于recvfrom()的第一部分 server = self.request[1] # 第二个元素是服务端本身,即自己 print("客户端的信息是:", self.client_address) # 对于UDP协议来说,相当于recvfrom()的第二部分,即客户端的ip+port print("收到客户机[{0}]的消息:[{1}]".format(self.client_address, data)) server.sendto(data.upper(),self.client_address) if __name__ == '__main__': s1 = socketserver.ThreadingUDPServer(("0.0.0.0", 6666), MyServer) # 公网服务器绑定 0.0.0.0 私网测试为 127.0.0.1 s1.serve_forever() # 启动服务
扩展:socketserver源码分析
探索socketserver中的继承关系
好了,接下来我们开始剖析socketserver
模块中的源码部分。在Pycharm下使用CTRL+鼠标左键
,可以进入源码进行查看。
我们在查看源码前一定要首先要明白两点:
socketserver
类分为两部分,其一是server
类主要是负责处理链接方面,另一类是request
类主要负责处理通信方面。
好了,请在脑子里记住这个概念。我们来看一些socketserver
模块的实现用了哪些其他的基础模块。
注意,接下来的源码注释部分我并没有在源代码中修改,也请读者不要修改源代码的任何内容。
import socket # 这模块挺熟悉吧 import selectors # 这个是一个多线程模块,主要支持I/O多路复用。 import os # 老朋友了 import sys # 老朋友 import threading # 多线程模块 from io import BufferedIOBase # 读写相关的模块 from time import monotonic as time # 老朋友time模块
好了,让我们接着往下走。可以看到一个变量__all__
,是不是觉得很熟悉?就是我们使用 from xxx import xxx
能导入进的东西全是被__all__
控制的,我们看一下它包含了哪些内容。
__all__ = ["BaseServer", "TCPServer", "UDPServer", "ThreadingUDPServer", "ThreadingTCPServer", "BaseRequestHandler", "StreamRequestHandler", "DatagramRequestHandler", "ThreadingMixIn"] # 这个是我们原本的 __all__ 中的值。 if hasattr(os, "fork"): __all__.extend(["ForkingUDPServer","ForkingTCPServer", "ForkingMixIn"]) if hasattr(socket, "AF_UNIX"): __all__.extend(["UnixStreamServer","UnixDatagramServer", "ThreadingUnixStreamServer", "ThreadingUnixDatagramServer"]) # 上面两个if判断是给__all__添加内容的,os.fork()这个方法是创建一个新的进程,并且只在类UNIX平台下才有效,Windows平台下是无效的,所以这里对于Windows平台来说就from socketserver import xxx 肯定少了三个类,这三个类的作用我们接下来会聊到。而关于socket中的AF_UNIX来说我们其实已经学习过了,是基于文件的socket家族。这在Windows上也是不支持的,只有在类UNIX平台下才有效。所以Windows平台下的导入又少了4个类。 # poll/select have the advantage of not requiring any extra file descriptor, # contrarily to epoll/kqueue (also, they require a single syscall). if hasattr(selectors, 'PollSelector'): _ServerSelector = selectors.PollSelector else: _ServerSelector = selectors.SelectSelector # 这两个if还是做I/O多路复用使用的,Windows平台下的结果是False,而类Unix平台下的该if结果为True,这关乎I/O多路复用的性能选择。到底是select还是poll或者epoll。
我们接着向下看源码,会看到许许多多的类。先关掉它来假设自己是解释器一行一行往下走会去执行那个部分。首先是一条if
判断
if hasattr(os, "fork"): class ForkingMixIn: pass # 这里我自己省略了 # 我们可以看见这条代码是接下来执行的,它意思还是如果在类Unix环境下,则会去创建该类。如果在Windows平台下则不会创建该类
继续走,其实这种if
判断再创建类的地方还有两处。我这里全部列出来:
if hasattr(os, "fork"): class ForkingUDPServer(ForkingMixIn, UDPServer): pass class ForkingTCPServer(ForkingMixIn, TCPServer): pass if hasattr(socket, 'AF_UNIX'): class UnixStreamServer(TCPServer): address_family = socket.AF_UNIX class UnixDatagramServer(UDPServer): address_family = socket.AF_UNIX class ThreadingUnixStreamServer(ThreadingMixIn, UnixStreamServer): passclass ThreadingUnixDatagramServer(ThreadingMixIn, UnixDatagramServer): pass
好了,说完了大体粗略的一个流程,我们该来研究这里面的类都有什么作用,这里可以查看每个类的文档信息。大致如下:
前面已经说过,socketserver
模块中主要分为两大类,我们就依照这个来进行划分。
socketserver模块源码内部class功能一览 | |
---|---|
处理链接相关 | |
BaseServer | 基础链接类 |
TCPServer | TCP协议类 |
UDPServer | UDP协议类 |
UnixStreamServer | 文件形式字节流类 |
UnixDatagramServer | 文件形式数据报类 |
处理通信相关 | |
BaseRequestHandler | 基础请求处理类 |
StreamRequestHandler | 字节流请求处理类 |
DatagramRequestHandler | 数据报请求处理类 |
多线程相关 | |
ThreadingMixIn | 线程方式 |
ThreadingUDPServer | 多线程UDP协议服务类 |
ThreadingTCPServer | 多线程TCP协议服务类 |
多进程相关 | |
ForkingMixIn | 进程方式 |
ForkingUDPServer | 多进程UDP协议服务类 |
ForkingTCPServer | 多进程TCP协议服务类 |
他们的继承关系如下:
ForkingUDPServer(ForkingMixIn, UDPServer)
ForkingTCPServer(ForkingMixIn, TCPServer)
ThreadingUDPServer(ThreadingMixIn, UDPServer)
ThreadingTCPServer(ThreadingMixIn, TCPServer)
StreamRequestHandler(BaseRequestHandler)
DatagramRequestHandler(BaseRequestHandler)
处理链接相关
处理通信相关
多线程相关
总继承关系(处理通信相关的不在其中,并且不包含多进程)
最后补上一个多进程的继承关系,就不放在总继承关系中了,容易图形造成混乱。
多进程相关
实例化过程分析
有了继承关系我们可以来模拟实例化的过程,我们以TCP协议为准:
socketserver.ThreadingTCPServer(("0.0.0.0", 6666), MyServer)
我们点进(选中上面代码的ThradingTCPServe
r部分,CTRL+鼠标左键
)源码部分,查找其 __init__
方法:
class ThreadingTCPServer(ThreadingMixIn, TCPServer): pass
看来没有,那么就找第一父类有没有,我们点进去可以看到第一父类ThreadingMixIn
也没有__init__
方法,看上面的继承关系图可以看出是普通多继承,那么就是广度优先的查找顺序。我们来看第二父类TCPServer
中有没有,看来第二父类中是有__init__
方法的,我们详细来看。
class TCPServer(BaseServer): """注释全被我删了,影响视线""" address_family = socket.AF_INET # 基于网络的套接字家族 socket_type = socket.SOCK_STREAM # TCP(字节流)协议 request_queue_size = 5 # 消息队列最大为5,可以理解为backlog,即半链接池的大小 allow_reuse_address = False # 端口重用默认关闭 def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True): """Constructor. May be extended, do not override.""" BaseServer.__init__(self, server_address, RequestHandlerClass) self.socket = socket.socket(self.address_family, self.socket_type) # 可以看见,上面先是调用了父类的__init__方法,然后又实例化出了一个socket对象!所以我们先不着急往下看,先看其父类中的__init__方法。 if bind_and_activate: try: self.server_bind() self.server_activate() except: self.server_close() raise
来看一下,BaseServer
类中的__init__
方法。
class BaseServer: """注释依旧全被我删了""" timeout = None # 这个变量可以理解为超时时间,先不着急说他。先看 __init__ 方法 def __init__(self, server_address, RequestHandlerClass): """Constructor. May be extended, do not override.""" self.server_address = server_address # 即我们传入的 ip+port ("0.0.0.0", 6666) self.RequestHandlerClass = RequestHandlerClass # 即我们传入的自定义类 MyServer self.__is_shut_down = threading.Event() # 这里可以看到执行了该方法,这里先不详解,因为它是一个事件锁,所以不用管 self.__shutdown_request = False
在BaseServer
中执行了thrading
模块下的Event()
方法。我这里还是提一嘴这个方法是干嘛用的,它会去控制线程的启动顺序,这里实例化出的self.__is_shut_down
其实就是一把锁,没什么深究的,接下来的文章中我也会写到。我们继续往下看,现在是该回到TCPServer
的__init__
方法中来了。
class TCPServer(BaseServer): """注释全被我删了,影响视线""" address_family = socket.AF_INET # 基于网络的套接字家族 socket_type = socket.SOCK_STREAM # TCP(字节流)协议 request_queue_size = 5 # 消息队列最大为5,可以理解为backlog,即半链接池的大小 allow_reuse_address = False # 端口重用默认关闭 def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True): # 看这里!!!! """Constructor. May be extended, do not override.""" BaseServer.__init__(self, server_address, RequestHandlerClass) self.socket = socket.socket(self.address_family, self.socket_type) if bind_and_activate: # 在创建完socket对象后就会进行该判断。默认参数bind_and_activate就是为True try: self.server_bind() # 现在进入该方法查看细节 self.server_activate() except: self.server_close() raise
好了,需要找这个self.bind()
方法,还是从头开始找。实例本身没有,第一父类ThreadingMixIn
也没有,所以现在我们看的是TCPServer
的server_bind()
方法:
def server_bind(self): """Called by constructor to bind the socket. May be overridden. """ if self.allow_reuse_address: # 这里的变量对应 TCPServer.__init__ 上面定义的类方法,端口重用这个。由于是False,所以我们直接往下执行。 self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.bind(self.server_address) # 绑定 ip+port 即 ("0.0.0.0", 6666) self.server_address = self.socket.getsockname() # 获取socket的名字 其实还是 ("0.0.0.0", 6666)
现在我们该看TCPServer
下的server_activate()
方法了。
def server_activate(self): """Called by constructor to activate the server. May be overridden. """ self.socket.listen(self.request_queue_size) # 其实就是监听半链接池,backlog为5
这个时候没有任何异常会抛出的,所以我们已经跑完了整个实例化的流程。并将其赋值给s1
现在我们看一下s1
的__dict__
字典,再接着进行源码分析。
{'server_address': ('0.0.0.0', 6666), 'RequestHandlerClass': <class '__main__.MyServer'>, '_BaseServer__is_shut_down': <threading.Event object at 0x000002A96A0208E0>, '_BaseServer__shutdown_request': False, 'socket': <socket.socket fd=716, family=AddressFamily.AF_INET, type=SocketKind.SOCK_STREAM, proto=0, laddr=('0.0.0.0', 6666)>}
server_forever()启动服务分析
我们接着来看下一条代码。
s1.serve_forever()
还是老规矩,由于s1
是ThreadingTCPServer
类的实例对象,所以我们去一层层的找serve_forever()
,最后在BaseServer
类中找到了。
def serve_forever(self, poll_interval=0.5): """注释被我删了""" self.__is_shut_down.clear() # 上面说过了那个Event锁,控制子线程的启动顺序。这里的clear()代表清除,这个不是重点,往下看。 try: # XXX: Consider using another file descriptor or connecting to the # socket to wake this up instead of polling. Polling reduces our # responsiveness to a shutdown request and wastes cpu at all other # times. with _ServerSelector() as selector: selector.register(self, selectors.EVENT_READ)# 这里是设置了一个监听类型为读取事件。也就是说当有请求来的时候当前socket对象就会发生反应。 while not self.__shutdown_request: # 为False,会执行,注意!下面都是死循环了!!! ready = selector.select(poll_interval) # 设置最大监听时间为0.5s # bpo-35017: shutdown() called during select(), exit immediately. if self.__shutdown_request: # BaseServer类中的类方法,为False,所以不执行这个。 break if ready: # 代表有链接请求会执行下面的方法 self._handle_request_noblock() # 这儿是比较重要的一个点。我们先来看。 self.service_actions() finally: self.__shutdown_request = False self.__is_shut_down.set() # 这里是一个释放锁的行为
如果有链接请求,则会执行self._handle_request_noblock()
方法,它在哪里呢?刚好这个方法就在BaseServer
中serve_forever()
方法的正下方第4个方法的位置。
def _handle_request_noblock(self): """注释被我删了""" try: request, client_address = self.get_request() # 这里的这个方法在TCPServer中,它的return值是 self.socket.accept(),就是就是返回了元组然后被解压赋值了。其实到这一步三次握手监听已经开启了。 except OSError: return if self.verify_request(request, client_address): # 这个是验证ip和port,返回的始终是True try: self.process_request(request, client_address) # request 双向链接通道,client_address客户端ip+port。现在我们来找这个方法。 except Exception: self.handle_error(request, client_address) self.shutdown_request(request) except: self.shutdown_request(request) raise else: self.shutdown_request(request)
现在开始查找self.process_request(request, client_address)
该方法,还是先从实例对象本身找,找不到去第一父类找。他位于第一父类ThreadingMixIn
中。
def process_request(self, request, client_address): """Start a new thread to process the request.""" t = threading.Thread(target = self.process_request_thread, args = (request, client_address)) # 创建子线程!!看这里! t.daemon = self.daemon_threads # ThreadingMixIn的类属性,为False if not t.daemon and self.block_on_close: # 第一个值为False,第二个值为True。他们都是ThreadingMixIn的类属性 if self._threads is None: # 会执行 self._threads = [] # 创建了空列表 self._threads.append(t) # 将当前的子线程添加至空列表中 t.start() # 开始当前子线程的运行,即运行self.process_request_thread方法
我们可以看到,这里的target
参数中指定了一个方法self.process_request_thread
,其实意思就是说当这个线程t
在start
的时候会去执行该方法。我们看一下它都做了什么,这个方法还是在ThreadingMixIn
类中。
def process_request_thread(self, request, client_address): """Same as in BaseServer but as a thread. In addition, exception handling is done here. """ try: self.finish_request(request, client_address) # 可以看到又执行该方法了,这里我再标注一下,别弄头晕了。request 双向链接通道,client_address客户端ip+port。 except Exception: self.handle_error(request, client_address) finally: self.shutdown_request(request) # 它不会关闭这个线程,而是将其设置为wait()状态。
看self.finish_request()
方法,它在BaseServer
类中
def finish_request(self, request, client_address): """Finish one request by instantiating RequestHandlerClass.""" self.RequestHandlerClass(request, client_address, self) # 这里是干嘛?其实就是在进行实例化!
self.RequestHandlerClass(request, client_address, self)
,我们找到self
的__dict__
字典,看看这个到底是什么东西
{'server_address': ('0.0.0.0', 6666), 'RequestHandlerClass': <class '__main__.MyServer'>, '_BaseServer__is_shut_down': <threading.Event object at 0x000002A96A0208E0>, '_BaseServer__shutdown_request': False, 'socket': <socket.socket fd=716, family=AddressFamily.AF_INET, type=SocketKind.SOCK_STREAM, proto=0, laddr=('0.0.0.0', 6666)>}
可以看到,它就是我们传入的那个类,即自定义的MyServer
类。我们把request,client_address,以及整个是实例self传给了MyServer的__init__
方法。但是我们的MyServer类没有__init__
,怎么办呢?去它父类BaseRequestHandler
里面找呗。
class BaseRequestHandler: """注释被我删了""" def __init__(self, request, client_address, server): self.request = request # request 双向链接通道 self.client_address = client_address # 客户端ip+port self.server = server # 即 实例对象本身。上面的__dict__就是它的__dict__ self.setup() # 钩子函数,我们可以自己写一个类然后继承`BaseRequestHandler`并覆写其setup方法即可。 try: self.handle() # 看,自动执行handle finally: self.finish() # 钩子函数 def setup(self): pass def handle(self): pass def finish(self): pass
现在我们知道了,为什么一定要覆写handle
方法了吧。
socketserver内部调用顺序流程图(基于TCP协议)
实例化过程图解
server_forever()启动服务图解
扩展:验证链接合法性
在很多时候,我们的TCP服务端为了防止网络泛洪可以设置一个三次握手验证机制。那么这个验证机制的实现其实也是非常简单的,我们的思路在于进入通信循环之前,客户端和服务端先走一次链接认证,只有通过认证的客户端才能够继续和服务端进行链接。
下面就来看一下具体的实现步骤。
#_*_coding:utf-8_*_ __author__ = 'Linhaifeng' from socket import * import hmac,os secret_key=b'linhaifeng bang bang bang' def conn_auth(conn): ''' 认证客户端链接 :param conn: :return: ''' print('开始验证新链接的合法性') msg=os.urandom(32) # 新方法,生成32位随机Bytes类型的值 conn.sendall(msg) h=hmac.new(secret_key,msg) digest=h.digest() respone=conn.recv(len(digest)) return hmac.compare_digest(respone,digest) # 对比结果为True或者为False def data_handler(conn,bufsize=1024): if not conn_auth(conn): print('该链接不合法,关闭') conn.close() return print('链接合法,开始通信') while True: data=conn.recv(bufsize) if not data:break conn.sendall(data.upper()) def server_handler(ip_port,bufsize,backlog=5): ''' 只处理链接 :param ip_port: :return: ''' tcp_socket_server=socket(AF_INET,SOCK_STREAM) tcp_socket_server.bind(ip_port) tcp_socket_server.listen(backlog) while True: conn,addr=tcp_socket_server.accept() print('新连接[%s:%s]' %(addr[0],addr[1])) data_handler(conn,bufsize) if __name__ == '__main__': ip_port=('127.0.0.1',9999) bufsize=1024 server_handler(ip_port,bufsize)
#_*_coding:utf-8_*_ __author__ = 'Linhaifeng' from socket import * import hmac,os secret_key=b'linhaifeng bang bang bang' def conn_auth(conn): ''' 验证客户端到服务器的链接 :param conn: :return: ''' msg=conn.recv(32) # 拿到随机位数 h=hmac.new(secret_key,msg) # 掺盐 digest=h.digest() conn.sendall(digest) def client_handler(ip_port,bufsize=1024): tcp_socket_client=socket(AF_INET,SOCK_STREAM) tcp_socket_client.connect(ip_port) conn_auth(tcp_socket_client) while True: data=input('>>: ').strip() if not data:continue if data == 'quit':break tcp_socket_client.sendall(data.encode('utf-8')) respone=tcp_socket_client.recv(bufsize) print(respone.decode('utf-8')) tcp_socket_client.close() if __name__ == '__main__': ip_port=('127.0.0.1',9999) bufsize=1024 client_handler(ip_port,bufsize)
到这里已经很简单了,服务器将随机数给客户机发过去,客户机收到后也用自家的盐与随机数加料,再使用digest()
将它转化为字节,直接发送了回来然后客户端通过hmac.compare_digest()