基于 asyncio 的简单 websocket 代理服务器

pip install websockets~=9.1

一个代理服务器：根据客户端消息中的 `server` 字段，把消息转发到对应的后端服务器，并把后端的回复（按消息中的 `from` 字段）返回给原来的客户端。

import asyncio
import json
import logging

import websockets
import websockets.legacy.client


MODULE_LOGGER = logging.getLogger(__name__)

# Backend servers to connect to, keyed by the value of the "server" field in
# client messages. 'addr' is the backend's websocket URI; 'msgs' starts as None
# and is replaced by an asyncio.Queue in init_message_queue().
SERVER_TYPES = {
    'audio_server': {
        'addr': 'ws://localhost:9999',
        'msgs': None,
    },
    'touch_server': {
        'addr': 'ws://localhost:8888',
        'msgs': None,
    },
}

# Dispatcher server: accepts user connections and receives their requests.
class Dispatcher:
    """A websocket server that waits for clients and queues their messages by server type.

    Every message should be a JSON object whose ``'server'`` field names an entry of
    SERVER_TYPES; ``(websocket, raw_message)`` is then put on that entry's queue.
    Connected clients are kept in ``self.clients`` so replies can be routed back.
    """

    def __init__(self, ip="localhost", port=9000):
        """
        :param ip: host/interface to listen on.
        :param port: port to listen on.
        """
        self._host = ip
        self._port = port
        self.clients = []
        self.server = None  # set by run() once the server is listening
        self._loop = asyncio.get_event_loop()
        # Keep a reference: the event loop only holds weak references to tasks.
        self._run_task = self._loop.create_task(self.run())

    async def run(self):
        """Create the websocket server; log startup failures (e.g. port in use) before re-raising."""
        try:
            self.server = await websockets.serve(self.dispatch, self._host, self._port)
        except OSError:
            MODULE_LOGGER.exception(f"Cannot listen on {self._host}:{self._port}")
            raise

    async def register_client(self, ws):
        """Register a connected client (no-op if it is already registered).

        :param ws: websocket client
        """
        if ws not in self.clients:
            self.clients.append(ws)

    async def unregister_client(self, ws):
        """Unregister a disconnected client (no-op if it is not registered).

        :param ws: websocket client.
        """
        if ws in self.clients:
            self.clients.remove(ws)

    async def dispatch(self, websocket, path):
        """Queue each message from ``websocket`` on the queue of the server type it names.

        Messages naming an unknown (or non-string) server type are logged and skipped;
        they no longer terminate the client's connection.

        :param websocket: the connected user client.
        :param path: request path (unused).
        """
        await self.register_client(websocket)
        try:
            async for message in websocket:  # recv message from a websocket client.
                msg_dict = self.parse_message(message)
                server_type = msg_dict.get('server')
                # A non-hashable value (e.g. a list) would make the `in` test raise.
                if isinstance(server_type, str) and server_type in SERVER_TYPES:
                    # put() waits while the queue is full, pausing reads from this client.
                    await SERVER_TYPES[server_type]['msgs'].put((websocket, message))
                    MODULE_LOGGER.info('Add one message to queue.')
                else:
                    MODULE_LOGGER.warning(f'Ignore message for unknown server: {server_type!r}')
        except Exception as exc:
            MODULE_LOGGER.error(exc)
        finally:
            await self.unregister_client(websocket)

    def parse_message(self, message):
        """Parse a JSON text message into a dict; return {} if it is not a JSON object.

            message should be like this:
                msg = {
                    'server': 'audio_server',  # touch_server, audio_server,...
                    'method': 'play_audio',
                    ...
                }
        :param message: raw message (str or bytes).
        :return: the decoded dict, or {} for invalid JSON or JSON that is not an object.
        """
        try:
            decoded = json.loads(message)
        except (ValueError, TypeError, RecursionError) as exc:
            MODULE_LOGGER.info(exc)
            return {}
        # Arrays/numbers/strings have no 'server' field and no .get(); treat as empty.
        return decoded if isinstance(decoded, dict) else {}

# Local client: connects to one backend server and forwards users' requests to it.
class Client:
    """Local client that connects to a specific backend server and relays messages.

    Requests taken from this server type's queue are tagged with the sender's
    address and sent to the backend; replies from the backend are routed back to
    the user client named in the reply's ``'from'`` field. ``run`` re-establishes
    the connection whenever it is lost.
    """

    def __init__(self, server, addr, msgs, reconnect_delay=1.0):
        """
        :param server: a Dispatcher instance (its ``clients`` are used to route replies).
        :param addr: the server's uri address.
        :param msgs: the message queue belonging to the server type; holds
            ``(user_client, message)`` tuples.
        :param reconnect_delay: seconds to wait between connection attempts.
        """
        self.server = server
        self._remote_addr = addr
        self._msgs = msgs
        self._reconnect_delay = reconnect_delay
        self.client = None  # websocket to the backend; None while disconnected

        self._loop = asyncio.get_event_loop()
        # Keep a reference: the event loop only holds weak references to tasks.
        self._run_task = self._loop.create_task(self.run())

    async def run(self):
        """Connect, relay until the connection breaks, then reconnect - forever.

        Exactly one sender and one receiver task exist per connection; when either
        one finishes (connection lost), the other is cancelled before reconnecting.
        """
        while True:
            await self._connect()
            if self.client is not None:
                workers = [
                    self._loop.create_task(self.send_message()),
                    self._loop.create_task(self.recv_message()),
                ]
                try:
                    await asyncio.wait(workers, return_when=asyncio.FIRST_COMPLETED)
                finally:
                    for worker in workers:
                        worker.cancel()
                    # Reap results so no exception is reported as "never retrieved".
                    results = await asyncio.gather(*workers, return_exceptions=True)
                    for result in results:
                        if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError):
                            MODULE_LOGGER.error(f"Relay task for {self._remote_addr} failed: {result!r}")
                    await self._close()
            await asyncio.sleep(self._reconnect_delay)

    async def _connect(self):
        """Try once to connect; on failure log it and leave ``self.client`` as None."""
        try:
            self.client = await websockets.legacy.client.connect(self._remote_addr)
        except Exception as exc:  # refused, bad URI, failed handshake, ...
            self.client = None
            MODULE_LOGGER.error(f"Cannot connect to: {self._remote_addr} ({exc})")

    async def _close(self):
        """Close the backend connection, if any, and forget it."""
        client, self.client = self.client, None
        if client is not None:
            await client.close()

    def get_user_client(self, address):
        """Return the connected user client whose remote address equals ``address``.

        :param list address: the sender address from a reply's ``'from'`` field
            (the ``remote_address`` tuple after a JSON round trip, i.e. a list).
        :return: the matching websocket, or None if that user is not connected.
        """
        for client in self.server.clients:
            remote = client.remote_address
            # remote_address may be None (e.g. transport already gone).
            if remote is not None and list(remote) == address:
                return client
        return None

    async def recv_message(self):
        """Forward every reply from the backend to the user client it is addressed to.

        Returns when the backend connection closes. A problem with a single reply
        (bad JSON, missing 'from', user already gone) only drops that reply.
        """
        try:
            async for msg in self.client:
                MODULE_LOGGER.info(f'recv msg :{msg}')
                await self._forward_reply(msg)
        except websockets.exceptions.ConnectionClosed as exc:
            MODULE_LOGGER.error(f"Connection to {self._remote_addr} closed: {exc}")

    async def _forward_reply(self, msg):
        """Send one backend reply to the user named in its ``'from'`` field."""
        try:
            to_client = json.loads(msg)['from']  # the server should send the 'from' keyword.
        except (ValueError, TypeError, KeyError) as exc:
            MODULE_LOGGER.error(f"Drop reply without a valid 'from' field: {exc!r}")
            return
        client = self.get_user_client(to_client)
        if client is None:
            return
        try:
            await client.send(msg)
        except websockets.exceptions.ConnectionClosed:
            # The user left; this must not tear down the backend connection.
            MODULE_LOGGER.info(f"User client {to_client} already closed, reply dropped.")

    async def send_message(self):
        """Send queued user requests to the server (``self._remote_addr``).

        Returns after the first failed send, having told that user about the
        failure; ``run`` then reconnects and starts a fresh sender.
        """
        while True:
            user_client, message = await self._msgs.get()  # tuple(client_obj, message)
            msg = self.add_msg_source(user_client, message)
            try:
                await self.client.send(msg)
            except Exception as exc:
                MODULE_LOGGER.error(f"Connect to server: {self._remote_addr} error!")
                MODULE_LOGGER.error(exc)
                # local client can't reach the server, so feed the error back to the user client.
                if user_client.open:
                    try:
                        await user_client.send("Error send message ...")  # Todo. The sent error data should be re-considered
                    except websockets.exceptions.ConnectionClosed:
                        pass  # the user disconnected in the meantime
                return
            MODULE_LOGGER.info(f"Send message to {self._remote_addr}")

    def add_msg_source(self, client, msg):
        """Add a 'from' field so the backend's reply can be routed to the sending client.

        :param client: the user websocket that sent ``msg``.
        :param msg: the raw JSON-object message.
        :return: the re-encoded message with ``'from'`` set to ``client.remote_address``.
        """
        msg_json = json.loads(msg)
        msg_json['from'] = client.remote_address
        return json.dumps(msg_json)


def init_message_queue(maxsize=5, server_types=None):
    """Create one bounded asyncio.Queue per server type to hold pending messages.

    :param int maxsize: capacity of each queue; ``put()`` waits while a queue is full.
    :param dict server_types: mapping of server type -> config dict, updated in place
        (each config's ``'msgs'`` entry is replaced). Defaults to SERVER_TYPES.
    :return: None
    """
    # SERVER_TYPES is only mutated, never rebound, so no ``global`` is needed.
    if server_types is None:
        server_types = SERVER_TYPES
    for config in server_types.values():
        config['msgs'] = asyncio.Queue(maxsize=maxsize)


if __name__ == "__main__":
    # Install a fresh loop first so the queues, the dispatcher and the clients
    # all bind to it before it starts running.
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)

    init_message_queue()
    dispatcher = Dispatcher()
    # One relaying Client per backend server type.
    for backend in SERVER_TYPES.values():
        Client(dispatcher, backend['addr'], backend['msgs'])

    event_loop.run_forever()

创建其他服务器（两个回显服务器，分别监听 9999 和 8888 端口）：

import asyncio

import websockets


async def hello(websocket, path):
    """Echo every received message back to the sender (stand-in backend server).

    :param websocket: the connected client.
    :param path: request path (unused).
    """
    print("Client connected...")
    try:
        async for message in websocket:
            print(message)
            await websocket.send(message)
            print(f"sent {message}")
    except websockets.ConnectionClosed:
        pass  # abnormal close: reported below just like a normal one
    finally:
        # Runs for clean closes too; the old bare ``except:`` only fired on errors.
        print('Client disconnected...')


async def main(port):
    """Start one echo server (``hello``) listening on localhost at ``port``.

    The returned server object is not kept; the script keeps the loop running instead.
    """
    await websockets.serve(hello, host='localhost', port=port)

if __name__ == "__main__":
    # Create the loop explicitly: calling asyncio.get_event_loop() with no
    # running loop is deprecated in recent Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    for port in (9999, 8888):  # audio_server, touch_server
        loop.create_task(main(port))
    loop.run_forever()

模拟用户连接（客户端连接到 9000 端口的转发服务器）：

import os
import json
import asyncio

import websockets


async def recv_from(ws):
    """Print every message received on ``ws`` until the connection closes.

    Iterating the connection ends quietly on a normal close; an abnormal close
    is reported instead of killing the task with an unretrieved exception.
    """
    try:
        async for msg in ws:
            print(f'recv from {msg}')
    except websockets.ConnectionClosed as exc:
        print(f'connection closed: {exc}')

# Sample requests, one per backend server type; 'num' is a running counter.
msg = dict(server='audio_server', method='play', num=0)
msg2 = dict(server='touch_server', method='play', num=0)

async def hello(msg):
    """Connect to the dispatcher and send ``msg`` as JSON once per second.

    Replies are printed by a background ``recv_from`` task. ``msg['num']`` is
    incremented in place before each send, so the caller's dict is modified.
    """
    uri = 'ws://localhost:9000'
    async with websockets.connect(uri) as conn:
        loop = asyncio.get_event_loop()
        loop.create_task(recv_from(conn))
        while True:
            await asyncio.sleep(1)
            msg['num'] = msg['num'] + 1
            payload = json.dumps(msg)
            await conn.send(payload)
            print("send to", msg)


if __name__ == "__main__":
    # Create the loop explicitly instead of relying on the deprecated implicit
    # loop creation by asyncio.get_event_loop().
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.create_task(hello(msg))
    # loop.create_task(hello(msg2))  # uncomment to also send to touch_server
    loop.run_forever()

上面共有三个脚本，执行顺序是：2、1、3（先启动后端回显服务器，再启动转发服务器，最后运行模拟用户）。

posted @ 2022-01-05 16:39  wztshine  阅读(746)  评论(0)  编辑  收藏  举报