Day 7: Advanced Sockets

In this section:

  1. Handling multiple connections with a socket
  2. A simple ssh with a socket
  3. True concurrency with SocketServer

 

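1. Handling multiple connections with a socket

The previous article introduced the basics of sockets, so we already know roughly how to set up communication between a client and a server. You will notice, though, that when the client disconnects, the server exits too. How do we keep the server alive so that it can serve the next client after the current one disconnects?

conn, addr = server.accept()  # accept a connection from a client; blocks until one arrives

This line is the key to handling multiple connections. As long as the server loops back to accept() after a client disconnects, it can go on to serve the next client.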

import socket

server = socket.socket()            # create the socket instance
server.bind(('localhost', 9999))    # bind ip + port
server.listen()                     # start listening

while True:                         # outer loop: keep accepting new connections
    print('waiting for new conn...')
    conn, addr = server.accept()    # wait for a new connection
    print('new conn:', addr)
    while True:                     # inner loop: keep talking to the connected client
        data = conn.recv(1024)
        if not data:
            print('client is off...')
            break                   # client is gone; go back to waiting for a new connection
        print('recv data:', data)
        conn.send(data.upper())
    conn.close()                    # release the finished connection
server.close()                      # not reached while the outer loop runs forever
Multi-connection server

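Note that at this point the server can still serve only one client at a time. Any other client that connects has to queue (its connection is left pending) until the current one disconnects. To handle several connections at the same time, read on.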
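The accept loop can be tried out in a single script. The following is a minimal sketch, not the article's own code: the helper names serve_sequentially and ask are made up for illustration, the server runs in a background thread, and port 0 asks the operating system for any free port. Two clients connect one after the other and both get served.

```python
import socket
import threading

def serve_sequentially(server, max_clients):
    """Serve `max_clients` clients one after another, then return."""
    for _ in range(max_clients):
        conn, addr = server.accept()        # blocks until a client connects
        with conn:
            while True:
                data = conn.recv(1024)
                if not data:                # client closed its side
                    break
                conn.sendall(data.upper())

def ask(port, message):
    """Connect, send one message, return the reply, then disconnect."""
    with socket.create_connection(("127.0.0.1", port)) as client:
        client.sendall(message)
        return client.recv(1024)

server = socket.socket()
server.bind(("127.0.0.1", 0))               # port 0: let the OS pick a free port
server.listen()
port = server.getsockname()[1]

worker = threading.Thread(target=serve_sequentially, args=(server, 2))
worker.start()

replies = [ask(port, b"first"), ask(port, b"second")]
worker.join()
server.close()
print(replies)
```

The second client is only answered because the first one has already disconnected; if the first connection stayed open, the second would sit in the listen queue.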

 

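2. A simple ssh with a socket

Besides sending simple messages, we can do something more advanced, such as a simple ssh-like tool (look it up if you are not sure what ssh is). The client connects to the server, the server executes the commands the client sends, and the results are returned to the client.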

import os
import socket

server = socket.socket()
server.bind(('localhost', 7777))
server.listen()

while True:
    print("waiting for a connection...")
    conn, addr = server.accept()
    print("client connected:", addr)
    while True:
        cmd = conn.recv(1024).decode()   # receive the command sent by the client
        if not cmd:                      # empty data: the client has disconnected
            break
        data = os.popen(cmd).read()      # run the command with os.popen and read its output
        print(data)
        if not data:                     # no output, e.g. the command does not exist
            data = "no such command"
        conn.send(data.encode("utf-8"))  # return the result to the client
    conn.close()

server.close()
ssh server
import socket

client = socket.socket()
client.connect(('localhost', 7777))
while True:
    cmd = input(">>>")                 # read a command
    if len(cmd) == 0:                  # skip empty input
        continue
    elif cmd == 'exit':                # quit on 'exit'
        print("bye!")
        break
    client.send(cmd.encode("utf-8"))   # commands must be encoded to bytes before sending
    data = client.recv(1024).decode()
    print(data)
client.close()
ssh client

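Good, we now have a simple ssh. But if you try enough commands, you will find the program still has problems:

  1. It cannot run commands such as top that keep producing output. After receiving a command, the server runs it through os.popen and returns the result only once the command has finished. A command like top never finishes under os.popen, so the client never gets a reply. (Real ssh is implemented with mechanisms such as select and asynchronous I/O, which we will cover later.)
  2. It cannot run commands such as cd that produce no output. After sending each command, the client calls client.recv(1024) and waits for the server's reply. cd produces no output, and calling conn.send() with empty data sends nothing to the client, so the client waits forever and hangs. The fix is for the server to check the length of the result and, if it is empty, substitute a reply of its own, such as "cmd exec success, has no output".
  3. If a command returns a large amount of data, the result comes back incomplete, and when you run another command on the client you get the second half of the previous command's result. Why? Because the client calls client.recv(1024), it receives at most 1024 bytes per call. If the server returns 2000 bytes, at least 900-odd bytes cannot be received the first time. That data is not thrown away: it stays in the operating system's socket buffers (the server's send buffer and the client's receive buffer) until the client calls recv again. That is why the client gets the rest of the first command's result when it runs the second command. You might suggest simply making client.recv(1024) bigger, say 100 MB at a time. That does not work: the argument to recv is only an upper limit, and a single call returns whatever has arrived so far, which is bounded by the socket buffers. Buffer sizes differ between systems; I could not find a definitive number, but in my tests a single recv on Linux returned at most about 10 MB. The Python documentation recommends a relatively small power of 2, for example 4096. So if the server returns more than one recv can take, for example the contents of a 5 MB file, how does the client receive all of it? The only way is to receive in a loop.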
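Problems 1 and 2 can also be tackled on the server by replacing os.popen with subprocess.run, which accepts a timeout. This is a swapped-in technique, not what the article's server does; the function name run_command and the 5-second default are assumptions made for illustration.

```python
import subprocess

def run_command(cmd, timeout=5.0):
    """Run `cmd` in a shell and always return some text for the client."""
    try:
        done = subprocess.run(
            cmd,
            shell=True,                  # same trust model as os.popen: never expose to untrusted clients
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,    # fold error messages into the reply
            timeout=timeout,             # give up on commands that never finish
        )
    except subprocess.TimeoutExpired:
        return "cmd timed out after {} s".format(timeout)
    output = done.stdout.decode(errors="replace")
    return output or "cmd exec success, has no output"

print(run_command("echo hello"))
print(run_command("cd ."))               # produces no output of its own
```

Note that with shell=True the timeout stops the shell that was started; whether every process the command spawned is also stopped depends on the platform, so treat this as a sketch rather than a complete answer to long-running commands.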

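  Before solving problem 3, consider this: the client has to keep receiving in a loop until the whole result of a command has arrived. But does the client know how much data the server will return? It does not. So how does it know how many times to loop? It cannot, and guessing is not an option. The server therefore has to tell the client how much data it is about to send before it starts sending the data itself. Let's do that.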
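Problem 3 and the receive loop can be reproduced in one process. The sketch below uses socket.socketpair() to get two connected sockets without a separate server; the helper name recv_exactly is hypothetical. The first recv(1024) returns at most 1024 of the 2000 bytes, and the loop collects the remainder once the total is known.

```python
import socket

def recv_exactly(sock, total):
    """Keep calling recv() until `total` bytes have arrived."""
    chunks = []
    received = 0
    while received < total:
        chunk = sock.recv(min(1024, total - received))
        if not chunk:                    # peer closed before sending everything
            raise ConnectionError("connection closed early")
        chunks.append(chunk)
        received += len(chunk)           # a chunk may be shorter than requested
    return b"".join(chunks)

sender, receiver = socket.socketpair()   # two connected stream sockets in one process
payload = b"x" * 2000
sender.sendall(payload)

first = receiver.recv(1024)              # at most 1024 bytes; the rest stays buffered
rest = recv_exactly(receiver, len(payload) - len(first))
print(len(first), len(rest))

sender.close()
receiver.close()
```

The loop only works because the total is known in advance, which is exactly why the server has to announce the size first.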

import os
import socket

server = socket.socket()
server.bind(('localhost', 7777))
server.listen()

while True:
    print("waiting for a connection...")
    conn, addr = server.accept()
    print("client connected:", addr)
    while True:
        cmd = conn.recv(1024).decode()
        if not cmd or cmd == 'exit':     # client disconnected or asked to quit
            print('client is off...')
            break
        data = os.popen(cmd).read()
        print(data)
        if not data:
            data = "cmd exec success, has no output"
        payload = data.encode("utf-8")   # measure the encoded bytes, not the characters
        conn.send(str(len(payload)).encode())  # before sending the data, tell the client its size
        conn.sendall(payload)  # send() may transmit only part of the data; sendall keeps calling send until everything is sent
    conn.close()

server.close()
ssh server: sending the size
import socket

client = socket.socket()
client.connect(('localhost', 7777))
while True:
    cmd = input(">>>")
    if len(cmd) == 0:
        continue
    elif cmd == 'exit':
        client.send(cmd.encode("utf-8"))
        print("bye!")
        break
    client.send(cmd.encode("utf-8"))
    res_return_size = client.recv(1024)  # receive the size of this command's result
    # data = client.recv(1024).decode()
    print(res_return_size)
client.close()
ssh client: receiving the size

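If you are lucky, you will see the program misbehave: the client meant to receive only the size of the command's result, but it actually received part of the result along with it. Why?

This brings in an important concept: "sticky packets". The server calls send twice, but TCP is a byte stream and does not preserve the boundaries between send calls. A send does not put the data on the wire immediately; it copies it into the system's socket send buffer, and the operating system decides when to transmit it, often combining several small writes into one segment because one larger transmission is more efficient than many small ones. On the receiving side, recv simply returns whatever bytes have arrived. The result is that the data from two or more sends can arrive stuck together in a single recv, which is exactly the failure described above.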
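The effect is easy to observe without a network. In this minimal sketch (socket.socketpair() stands in for the client and server connection, and the payloads are made up), the sender calls sendall twice, yet the receiver has no way to tell where the first message ends.

```python
import socket

sender, receiver = socket.socketpair()   # a connected pair of stream sockets

sender.sendall(b"5")                     # first send: the size
sender.sendall(b"hello")                 # second send: the result
sender.close()                           # no more data will follow

first = receiver.recv(1024)              # may already hold both sends, e.g. b"5hello"
rest = b""
while True:                              # drain whatever the first recv did not return
    chunk = receiver.recv(1024)
    if not chunk:
        break
    rest += chunk
receiver.close()

print("first recv:", first)
print("whole stream:", first + rest)
```

Whether the first recv returns one message or both is up to the operating system; the only guarantee is that the bytes arrive in order.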

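We have to separate the stuck-together data, because otherwise we cannot extract the size of the result that the server sent. So how? We cannot force the send buffer to flush and push the data to the client. What we can do is make sure the second piece of data is not sent until the first has been delivered, so the two cannot be combined. How?

The answer:

  1. low-way: time.sleep(0.5). In my repeated tests, making the server sleep for at least 0.5 s between the two sends was enough for the first piece to go out on its own. You may ask whether this would get you fired. 0.5 s does not sound like much, but in scenarios with strict real-time requirements, such as stock trading, prices can move a lot in 0.5 s. The delay also guarantees nothing; it only makes sticking less likely. It is what I did when I first learned sockets because I could not find anything better, and looking back it really is low.
  2. high-way: every time the server sends a piece of data, it immediately waits for the client to respond by calling conn.recv(1024). recv blocks until data arrives, so the server does not execute the following conn.sendall(command result) until it has received the client's response. By then the client has already read the size, so the result that follows cannot stick to it. See the implementation below.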
import os
import socket

server = socket.socket()
server.bind(('localhost', 7777))
server.listen()

while True:
    print("waiting for a connection...")
    conn, addr = server.accept()
    print("client connected:", addr)
    while True:
        cmd = conn.recv(1024).decode()
        if not cmd or cmd == 'exit':     # client disconnected or asked to quit
            print('client is off...')
            break
        data = os.popen(cmd).read()
        print(data)
        if not data:
            data = "cmd exec success, has no output"
        payload = data.encode("utf-8")   # measure the encoded bytes, not the characters
        conn.send(str(len(payload)).encode())
        conn.recv(1024)  # wait for the client's acknowledgement (content ignored) to prevent sticking
        conn.sendall(payload)
    conn.close()

server.close()
ssh server: no sticking
import socket

client = socket.socket()
client.connect(('localhost', 7777))
while True:
    cmd = input(">>>")
    if len(cmd) == 0:
        continue
    elif cmd == 'exit':
        client.send(cmd.encode("utf-8"))
        print("bye!")
        break
    client.send(cmd.encode("utf-8"))

    res_return_size = client.recv(1024)  # receive the size of this command's result
    total_rece_size = int(res_return_size)
    print("total size:", total_rece_size)
    client.send("ready to receive".encode("utf-8"))  # send the acknowledgement
    received_size = 0  # bytes received so far
    cmd_res = b''
    while received_size < total_rece_size:  # not finished yet
        data = client.recv(1024)
        if not data:  # server closed the connection early
            break
        received_size += len(data)  # use len(data), not 1024: a recv may return fewer than 1024 bytes
        cmd_res += data
    print("all data received", received_size)
    # print(cmd_res.decode())  # the command's result
    with open("test_copy.html", "wb") as f:  # save the result so we can check that it arrived intact
        f.write(cmd_res)

client.close()
ssh client: no sticking
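The acknowledgement costs one extra round trip per command. A commonly used alternative, shown here as a sketch rather than as the article's method, is to prefix every message with a fixed-size length header so the receiver always knows where a message ends. The names HEADER, send_msg, recv_exactly and recv_msg are made up for this example, and the 4-byte big-endian header is an arbitrary choice.

```python
import socket
import struct

HEADER = struct.Struct("!I")             # 4-byte unsigned length, network byte order

def send_msg(sock, payload):
    """Send the length header followed by the payload bytes."""
    sock.sendall(HEADER.pack(len(payload)) + payload)

def recv_exactly(sock, n):
    """Read exactly n bytes, looping because recv may return fewer."""
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("connection closed mid-message")
        buf.extend(chunk)
    return bytes(buf)

def recv_msg(sock):
    """Read one header, then exactly as many payload bytes as it announces."""
    (length,) = HEADER.unpack(recv_exactly(sock, HEADER.size))
    return recv_exactly(sock, length)

a, b = socket.socketpair()
send_msg(a, b"first result")             # two messages sent back to back
send_msg(a, b"second result")
m1 = recv_msg(b)
m2 = recv_msg(b)
print(m1, m2)
a.close()
b.close()
```

Because the receiver never asks for more bytes than the header announced, two messages sent back to back cannot bleed into each other, and no sleep or acknowledgement is needed.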

 

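3. True concurrency with SocketServer

SocketServer (the socketserver module in Python 3) uses I/O multiplexing together with multithreading or multiprocessing internally to build a socket server that handles multiple client requests concurrently. That is, each time a client connects, the server creates a thread or a process on the server machine dedicated to handling all of that client's requests.

ThreadingTCPServer

A socket server built on ThreadingTCPServer creates one thread for each client, and that thread handles the interaction with the client.

1) ThreadingTCPServer basics

Using ThreadingTCPServer:

  • Create a class that inherits from socketserver.BaseRequestHandler
  • Define a method named handle in that class
  • Start the ThreadingTCPServer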
import socketserver

class MyTCPHandler(socketserver.BaseRequestHandler):
    """
    The request handler class for our server.

    It is instantiated once per connection to the server, and must
    override the handle() method to implement communication to the
    client.
    """
    def handle(self):
        # self.request is the TCP socket connected to the client
        while True:
            try:
                data = self.request.recv(1024)
                if not data:  # the client closed the connection
                    break
                self.data = data.strip()
                print("{} wrote:".format(self.client_address[0]))
                print(self.data)
                # just send back the same data, but upper-cased
                self.request.sendall(self.data.upper())
            except ConnectionResetError as e:
                print("error:", e)
                break

if __name__ == "__main__":
    HOST, PORT = "localhost", 9999

    # Create the server, binding to localhost on port 9999
    server = socketserver.ThreadingTCPServer((HOST, PORT), MyTCPHandler)

    # Activate the server; this will keep running until you
    # interrupt the program with Ctrl-C
    server.serve_forever()
SocketServer server
import socket
# nothing special about the client; it is the same as before
client = socket.socket()
client.connect(("localhost", 9999))

while True:
    data = input(">>>").strip()
    if not data:  # sending nothing would leave recv waiting forever
        continue
    client.send(data.encode("utf-8"))
    recv_data = client.recv(1024).decode()
    print(recv_data)
Client
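To check that the threaded server really serves clients at the same time, the sketch below keeps two connections open at once and talks to the second one first. It is self-contained for illustration: the handler name UpperHandler is made up, the server runs in a background thread, and port 0 lets the OS choose a free port.

```python
import socket
import socketserver
import threading

class UpperHandler(socketserver.BaseRequestHandler):
    def handle(self):
        while True:
            data = self.request.recv(1024)
            if not data:                     # client closed the connection
                break
            self.request.sendall(data.upper())

server = socketserver.ThreadingTCPServer(("127.0.0.1", 0), UpperHandler)
server.daemon_threads = True                 # handler threads will not block shutdown
port = server.server_address[1]
threading.Thread(target=server.serve_forever, daemon=True).start()

# Open two connections and keep both open at the same time.
c1 = socket.create_connection(("127.0.0.1", port))
c2 = socket.create_connection(("127.0.0.1", port))
c2.sendall(b"second client")                 # answered although c1 connected first and is still open
reply2 = c2.recv(1024)
c1.sendall(b"first client")
reply1 = c1.recv(1024)
print(reply2, reply1)

c1.close()
c2.close()
server.shutdown()
server.server_close()
```

With the single-threaded accept loop from section 1, the recv on c2 would block until c1 disconnected.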

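2) ThreadingTCPServer source walkthrough

ThreadingTCPServer inherits from ThreadingMixIn and TCPServer, and TCPServer in turn inherits from BaseServer.

The internal call flow is:

  • Start the server program
  • TCPServer.__init__ runs, creating the server socket object and binding the IP and port
  • BaseServer.__init__ runs, assigning the custom class that inherits from socketserver.BaseRequestHandler (MyRequestHandler) to self.RequestHandlerClass
  • BaseServer.serve_forever runs; its while loop keeps watching for incoming client requests ...
  • A client connects to the server
  • ThreadingMixIn.process_request runs, creating a thread to handle the request
  • ThreadingMixIn.process_request_thread runs
  • BaseServer.finish_request runs and calls self.RequestHandlerClass(), that is, the constructor of the custom MyRequestHandler (which automatically calls the base class BaseRequestHandler's constructor, and that constructor in turn calls MyRequestHandler's handle method)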
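The call flow above can be observed by overriding two of the hooks and recording when each one runs. This is only a tracing sketch under Python 3's socketserver; TracingServer, EchoHandler and the calls list are hypothetical names introduced for the example.

```python
import socket
import socketserver
import threading

calls = []                                   # names of the hooks, in the order they run

class TracingServer(socketserver.ThreadingTCPServer):
    daemon_threads = True

    def process_request(self, request, client_address):
        calls.append("process_request")      # runs in the serve_forever thread
        super().process_request(request, client_address)

    def finish_request(self, request, client_address):
        calls.append("finish_request")       # runs in the per-client thread
        super().finish_request(request, client_address)

class EchoHandler(socketserver.BaseRequestHandler):
    def handle(self):
        calls.append("handle")               # invoked from the handler's constructor
        self.request.sendall(self.request.recv(1024))

server = TracingServer(("127.0.0.1", 0), EchoHandler)
threading.Thread(target=server.serve_forever, daemon=True).start()

with socket.create_connection(server.server_address) as client:
    client.sendall(b"ping")
    client.recv(1024)                        # the reply proves handle() has run

server.shutdown()
server.server_close()
print(calls)
```

The recorded order is process_request, then finish_request, then handle, matching the last three steps of the list.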

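Relevant ThreadingTCPServer source (the listing uses print statements, so it comes from the Python 2 SocketServer module; Python 3's socketserver keeps the same class layout):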

class BaseServer:

    """Base class for server classes.

    Methods for the caller:

    - __init__(server_address, RequestHandlerClass)
    - serve_forever(poll_interval=0.5)
    - shutdown()
    - handle_request()  # if you do not use serve_forever()
    - fileno() -> int   # for select()

    Methods that may be overridden:

    - server_bind()
    - server_activate()
    - get_request() -> request, client_address
    - handle_timeout()
    - verify_request(request, client_address)
    - server_close()
    - process_request(request, client_address)
    - shutdown_request(request)
    - close_request(request)
    - handle_error()

    Methods for derived classes:

    - finish_request(request, client_address)

    Class variables that may be overridden by derived classes or
    instances:

    - timeout
    - address_family
    - socket_type
    - allow_reuse_address

    Instance variables:

    - RequestHandlerClass
    - socket

    """

    timeout = None

    def __init__(self, server_address, RequestHandlerClass):
        """Constructor.  May be extended, do not override."""
        self.server_address = server_address
        self.RequestHandlerClass = RequestHandlerClass
        self.__is_shut_down = threading.Event()
        self.__shutdown_request = False

    def server_activate(self):
        """Called by constructor to activate the server.

        May be overridden.

        """
        pass

    def serve_forever(self, poll_interval=0.5):
        """Handle one request at a time until shutdown.

        Polls for shutdown every poll_interval seconds. Ignores
        self.timeout. If you need to do periodic tasks, do them in
        another thread.
        """
        self.__is_shut_down.clear()
        try:
            while not self.__shutdown_request:
                # XXX: Consider using another file descriptor or
                # connecting to the socket to wake this up instead of
                # polling. Polling reduces our responsiveness to a
                # shutdown request and wastes cpu at all other times.
                r, w, e = _eintr_retry(select.select, [self], [], [],
                                       poll_interval)
                if self in r:
                    self._handle_request_noblock()
        finally:
            self.__shutdown_request = False
            self.__is_shut_down.set()

    def shutdown(self):
        """Stops the serve_forever loop.

        Blocks until the loop has finished. This must be called while
        serve_forever() is running in another thread, or it will
        deadlock.
        """
        self.__shutdown_request = True
        self.__is_shut_down.wait()

    # The distinction between handling, getting, processing and
    # finishing a request is fairly arbitrary.  Remember:
    #
    # - handle_request() is the top-level call.  It calls
    #   select, get_request(), verify_request() and process_request()
    # - get_request() is different for stream or datagram sockets
    # - process_request() is the place that may fork a new process
    #   or create a new thread to finish the request
    # - finish_request() instantiates the request handler class;
    #   this constructor will handle the request all by itself

    def handle_request(self):
        """Handle one request, possibly blocking.

        Respects self.timeout.
        """
        # Support people who used socket.settimeout() to escape
        # handle_request before self.timeout was available.
        timeout = self.socket.gettimeout()
        if timeout is None:
            timeout = self.timeout
        elif self.timeout is not None:
            timeout = min(timeout, self.timeout)
        fd_sets = _eintr_retry(select.select, [self], [], [], timeout)
        if not fd_sets[0]:
            self.handle_timeout()
            return
        self._handle_request_noblock()

    def _handle_request_noblock(self):
        """Handle one request, without blocking.

        I assume that select.select has returned that the socket is
        readable before this function was called, so there should be
        no risk of blocking in get_request().
        """
        try:
            request, client_address = self.get_request()
        except socket.error:
            return
        if self.verify_request(request, client_address):
            try:
                self.process_request(request, client_address)
            except:
                self.handle_error(request, client_address)
                self.shutdown_request(request)

    def handle_timeout(self):
        """Called if no new request arrives within self.timeout.

        Overridden by ForkingMixIn.
        """
        pass

    def verify_request(self, request, client_address):
        """Verify the request.  May be overridden.

        Return True if we should proceed with this request.

        """
        return True

    def process_request(self, request, client_address):
        """Call finish_request.

        Overridden by ForkingMixIn and ThreadingMixIn.

        """
        self.finish_request(request, client_address)
        self.shutdown_request(request)

    def server_close(self):
        """Called to clean-up the server.

        May be overridden.

        """
        pass

    def finish_request(self, request, client_address):
        """Finish one request by instantiating RequestHandlerClass."""
        self.RequestHandlerClass(request, client_address, self)

    def shutdown_request(self, request):
        """Called to shutdown and close an individual request."""
        self.close_request(request)

    def close_request(self, request):
        """Called to clean up an individual request."""
        pass

    def handle_error(self, request, client_address):
        """Handle an error gracefully.  May be overridden.

        The default is to print a traceback and continue.

        """
        print '-'*40
        print 'Exception happened during processing of request from',
        print client_address
        import traceback
        traceback.print_exc() # XXX But this goes to stderr!
        print '-'*40

BaseServer
class TCPServer(BaseServer):

    """Base class for various socket-based server classes.

    Defaults to synchronous IP stream (i.e., TCP).

    Methods for the caller:

    - __init__(server_address, RequestHandlerClass, bind_and_activate=True)
    - serve_forever(poll_interval=0.5)
    - shutdown()
    - handle_request()  # if you don't use serve_forever()
    - fileno() -> int   # for select()

    Methods that may be overridden:

    - server_bind()
    - server_activate()
    - get_request() -> request, client_address
    - handle_timeout()
    - verify_request(request, client_address)
    - process_request(request, client_address)
    - shutdown_request(request)
    - close_request(request)
    - handle_error()

    Methods for derived classes:

    - finish_request(request, client_address)

    Class variables that may be overridden by derived classes or
    instances:

    - timeout
    - address_family
    - socket_type
    - request_queue_size (only for stream sockets)
    - allow_reuse_address

    Instance variables:

    - server_address
    - RequestHandlerClass
    - socket

    """

    address_family = socket.AF_INET

    socket_type = socket.SOCK_STREAM

    request_queue_size = 5

    allow_reuse_address = False

    def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True):
        """Constructor.  May be extended, do not override."""
        BaseServer.__init__(self, server_address, RequestHandlerClass)
        self.socket = socket.socket(self.address_family,
                                    self.socket_type)
        if bind_and_activate:
            try:
                self.server_bind()
                self.server_activate()
            except:
                self.server_close()
                raise

    def server_bind(self):
        """Called by constructor to bind the socket.

        May be overridden.

        """
        if self.allow_reuse_address:
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind(self.server_address)
        self.server_address = self.socket.getsockname()

    def server_activate(self):
        """Called by constructor to activate the server.

        May be overridden.

        """
        self.socket.listen(self.request_queue_size)

    def server_close(self):
        """Called to clean-up the server.

        May be overridden.

        """
        self.socket.close()

    def fileno(self):
        """Return socket file number.

        Interface required by select().

        """
        return self.socket.fileno()

    def get_request(self):
        """Get the request and client address from the socket.

        May be overridden.

        """
        return self.socket.accept()

    def shutdown_request(self, request):
        """Called to shutdown and close an individual request."""
        try:
            #explicitly shutdown.  socket.close() merely releases
            #the socket and waits for GC to perform the actual close.
            request.shutdown(socket.SHUT_WR)
        except socket.error:
            pass #some platforms may raise ENOTCONN here
        self.close_request(request)

    def close_request(self, request):
        """Called to clean up an individual request."""
        request.close()
TCPServer
class ThreadingMixIn:
    """Mix-in class to handle each request in a new thread."""

    # Decides how threads will act upon termination of the
    # main process
    daemon_threads = False

    def process_request_thread(self, request, client_address):
        """Same as in BaseServer but as a thread.

        In addition, exception handling is done here.

        """
        try:
            self.finish_request(request, client_address)
            self.shutdown_request(request)
        except:
            self.handle_error(request, client_address)
            self.shutdown_request(request)

    def process_request(self, request, client_address):
        """Start a new thread to process the request."""
        t = threading.Thread(target = self.process_request_thread,
                             args = (request, client_address))
        t.daemon = self.daemon_threads
        t.start()
ThreadingMixIn
class ThreadingTCPServer(ThreadingMixIn, TCPServer): pass
ThreadingTCPServer
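The one-line class definition works because of Python's method resolution order: the mix-in is listed first, so its process_request is found before the one BaseServer defines. A quick check against Python 3's socketserver (the variable names are only for illustration):

```python
import socketserver

# Classes are searched left to right along the MRO when a method is looked up.
mro_names = [cls.__name__ for cls in socketserver.ThreadingTCPServer.__mro__]
print(mro_names)

# The mix-in comes before TCPServer and BaseServer, so its process_request wins.
overrides = (socketserver.ThreadingTCPServer.process_request
             is socketserver.ThreadingMixIn.process_request)
print(overrides)
```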

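ForkingTCPServer

ForkingTCPServer is used in almost the same way as ThreadingTCPServer and follows the same flow. The difference is that internally it creates a process, rather than a thread, for each client. It relies on os.fork(), so it is only available on POSIX platforms.

Basic usage: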

import socketserver

class MyTCPHandler(socketserver.BaseRequestHandler):
    """
    The request handler class for our server.

    It is instantiated once per connection to the server, and must
    override the handle() method to implement communication to the
    client.
    """

    def handle(self):
        # self.request is the TCP socket connected to the client
        while True:
            try:
                data = self.request.recv(1024)
                if not data:  # the client closed the connection
                    break
                self.data = data.strip()
                print("{} wrote:".format(self.client_address[0]))
                print(self.data)
                # just send back the same data, but upper-cased
                self.request.sendall(self.data.upper())
            except ConnectionResetError as e:
                print("error:", e)
                break

if __name__ == "__main__":
    HOST, PORT = "localhost", 9999

    # Create the server, binding to localhost on port 9999
    server = socketserver.ForkingTCPServer((HOST, PORT), MyTCPHandler)

    # Activate the server; this will keep running until you
    # interrupt the program with Ctrl-C
    server.serve_forever()
Server
import socket

client = socket.socket()
client.connect(("localhost", 9999))

while True:
    data = input(">>>").strip()
    if not data:  # sending nothing would leave recv waiting forever
        continue
    client.send(data.encode("utf-8"))
    recv_data = client.recv(1024).decode()
    print(recv_data)
Client
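Since ForkingTCPServer is defined only on platforms that provide os.fork(), a script meant to run everywhere can fall back to the threaded variant. A small sketch; the name ServerClass is an assumption for illustration.

```python
import socketserver

# ForkingTCPServer is missing where os.fork() is unavailable (for example Windows),
# so fall back to the thread-per-client server there.
ServerClass = getattr(socketserver, "ForkingTCPServer",
                      socketserver.ThreadingTCPServer)
print(ServerClass.__name__)
```

Either class accepts the same (address, handler) arguments, so the rest of the server code does not change.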

 

Posted 2017-08-07 20:57 by Iron_boy