import selectors
import threading
import socket
import datetime
import logging
from queue import Queue


# Root logging: INFO and above, each record prefixed with its timestamp and the
# id of the emitting thread (the chat event loop runs on its own thread).
logging.basicConfig(format="%(asctime)s %(thread)d %(message)s", level=logging.INFO)


class Conn:
    """Per-client record kept by ChatServer.

    Holds the client's socket, the callback the event loop invokes as
    ``handle(sock, mask)`` when that socket is ready, and a FIFO of byte
    payloads waiting to be sent to the client.
    """

    def __init__(self, conn: socket.socket, handle):
        self.conn = conn  # the accepted client socket
        self.handle = handle  # readiness callback for this client
        self.queue = Queue()  # outbound messages, oldest first


class ChatServer:
    """Selector-based TCP chat server whose event loop runs on one daemon thread.

    Every chunk received from a client (up to 1024 bytes per ``recv``) is
    stamped with the local time and the sender's address, then queued for
    every connected client, the sender included. A client is dropped when it
    sends exactly ``b"quit"``, closes its end of the connection, or its
    socket reports an error.

    Incoming bytes are decoded as GBK, falling back to UTF-8 (undecodable
    bytes become U+FFFD). Outgoing text is encoded as GBK when possible,
    otherwise as UTF-8.
    """

    # Maximum time (seconds) a single select() call blocks. stop() may wait up
    # to this long for the event loop to notice the shutdown flag.
    select_timeout = 0.5

    def __init__(self, ip="127.0.0.1", port=60000):
        self.ip = ip
        self.port = port
        self.addr = ip, port
        self.sock = socket.socket()
        self.selector = selectors.DefaultSelector()
        self.clients = {}  # peer address -> Conn
        self.is_shutdown = threading.Event()
        self._thread = None  # event-loop thread, created by start()
        # Client socket -> peer address, recorded at accept() time because
        # getpeername() fails once the socket is closed or reset.
        self._peers = {}
        # Client socket -> tail of a buffer that send() only partly wrote.
        self._pending = {}

    def start(self):
        """Bind and listen on ``self.addr``, then run the event loop on a daemon thread."""
        self.sock.bind(self.addr)
        self.sock.listen(128)
        self.sock.setblocking(False)
        self.selector.register(self.sock, selectors.EVENT_READ, self._accept)
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self):
        """Dispatch selector events until ``is_shutdown`` is set."""
        while not self.is_shutdown.is_set():
            # A finite timeout lets the loop observe is_shutdown even when idle.
            for key, mask in self.selector.select(timeout=self.select_timeout):
                callback = key.data
                if callable(callback):  # the listening socket's _accept
                    callback(key.fileobj)
                else:  # a Conn registered for a client socket
                    callback.handle(key.fileobj, mask)

    def _accept(self, sock: socket.socket):
        """Accept one pending connection and start watching it for input."""
        try:
            conn, client = sock.accept()
        except (BlockingIOError, ConnectionAbortedError):
            return  # the connection vanished between select() and accept()
        conn.setblocking(False)
        self._peers[conn] = client
        self.clients[client] = Conn(conn, self._handle)
        # EVENT_READ only: EVENT_WRITE is added while output is queued, because
        # an idle socket is always writable and would make select() spin.
        self.selector.register(conn, selectors.EVENT_READ, self.clients[client])

    def _handle(self, conn: socket.socket, mask: int):
        """React to readiness *mask* on client socket *conn*.

        On EVENT_READ, receive one chunk and broadcast it, or disconnect on
        ``b"quit"``, EOF or a socket error. On EVENT_WRITE, flush queued output.
        """
        addr = self._peers.get(conn)
        if addr is None:
            return  # already disconnected
        if mask & selectors.EVENT_READ:
            try:
                data = conn.recv(1024)
            except OSError as e:
                logging.info(e)
                data = b""
            # b"" means the peer closed its end. Treat it like "quit" instead
            # of broadcasting empty messages on every loop iteration.
            if not data or data == b"quit":
                self._disconnect(conn)
                return
            msg = self._format_message(addr, data, datetime.datetime.now())
            logging.info(msg)
            self._broadcast(self._encode(msg))
        if mask & selectors.EVENT_WRITE:
            self._flush(conn)

    def _broadcast(self, payload: bytes):
        """Queue *payload* for every client and subscribe each one to EVENT_WRITE."""
        for client in self.clients.values():
            client.queue.put(payload)
            self.selector.modify(
                client.conn, selectors.EVENT_READ | selectors.EVENT_WRITE, client
            )

    def _flush(self, conn: socket.socket):
        """Send as much queued output for *conn* as the socket accepts.

        Any unsent tail is kept for the next EVENT_WRITE. Once everything is
        sent, the socket goes back to being watched for EVENT_READ only.
        """
        client = self.clients[self._peers[conn]]
        chunks = [self._pending.pop(conn, b"")]
        while not client.queue.empty():
            chunks.append(client.queue.get_nowait())
        buf = b"".join(chunks)
        try:
            sent = conn.send(buf) if buf else 0
        except BlockingIOError:
            sent = 0
        except OSError as e:
            logging.info(e)
            self._disconnect(conn)
            return
        if sent < len(buf):
            # Kernel send buffer is full: keep the tail, stay on EVENT_WRITE.
            self._pending[conn] = buf[sent:]
        else:
            self.selector.modify(conn, selectors.EVENT_READ, client)

    def _disconnect(self, conn: socket.socket):
        """Forget *conn*, stop watching it and close it. Safe to call twice."""
        addr = self._peers.pop(conn, None)
        if addr is not None:
            self.clients.pop(addr, None)
        self._pending.pop(conn, None)
        try:
            # Unregister before close(): afterwards the fd number can be reused
            # by a new connection and would clash with the stale registration.
            self.selector.unregister(conn)
        except (KeyError, ValueError):
            pass
        conn.close()

    @staticmethod
    def _format_message(addr, data: bytes, now: datetime.datetime) -> str:
        """Decode *data* and prefix it with *now* and ``host:port`` from *addr*."""
        try:
            text = data.decode("gbk")
        except UnicodeDecodeError:
            # Undecodable bytes must not raise: that would kill the event loop.
            text = data.decode("utf-8", errors="replace")
        return "{:%Y/%m/%d %H:%M:%S} {}:{}\r\n {}".format(now, addr[0], addr[1], text)

    @staticmethod
    def _encode(text: str) -> bytes:
        """Encode *text* as GBK when possible, otherwise as UTF-8."""
        try:
            return text.encode("gbk")
        except UnicodeEncodeError:
            return text.encode()

    def stop(self):
        """Stop the event loop, then close every socket and the selector.

        Idempotent, and safe to call even if start() was never called.
        """
        self.is_shutdown.set()
        if self._thread is not None:
            # Wait for the loop to exit so sockets are never closed underneath
            # a select() running on the other thread.
            self._thread.join()
            self._thread = None
        key_map = self.selector.get_map()
        if key_map is not None:  # None once the selector has been closed
            for key in list(key_map.values()):
                self.selector.unregister(key.fileobj)
                key.fileobj.close()
            self.selector.close()
        self.sock.close()  # also covers stop() without a prior start()
        self.clients.clear()
        self._peers.clear()
        self._pending.clear()


def main():
    """Run a ChatServer on its default address until the operator stops it.

    Typing ``quit`` at the prompt, or closing stdin, shuts the server down.
    The server is stopped on every exit path, including an exception such as
    KeyboardInterrupt, so its sockets are always released.
    """
    cs = ChatServer()
    cs.start()
    try:
        while True:
            try:
                cmd = input(">>>").strip()
            except EOFError:
                break  # stdin closed: nobody can type "quit" any more
            if cmd == "quit":
                break
    finally:
        cs.stop()


if __name__ == "__main__":
    # Start the interactive server only when executed as a script, not on import.
    main()