基于TCP协议的socket通信实现

Posted on 2022-11-27 21:20  呱呱呱呱叽里呱啦  阅读(32)  评论(0)  编辑  收藏  举报

基于TCP协议的socket通信实现

服务端

import socketserver
import subprocess
import struct
import os
import json


def send_file(cmd, connect_obj, file_exist_dir='server_file'):  # customizable file directory next to this script
    """Send the file named in a ``download_file <name>`` command to the peer.

    Wire format: a 4-byte length (``struct`` format ``'i'``, native byte
    order), then a UTF-8 JSON header, then the raw file bytes. Afterwards the
    peer's acknowledgement (up to 1024 bytes) is read and printed. If the file
    does not exist, the JSON string ``'file not found...'`` is sent as the
    header instead and no acknowledgement is awaited.

    :param cmd: full command text, e.g. ``'download_file a.txt'``.
    :param connect_obj: connected socket to the peer.
    :param file_exist_dir: directory (relative to this script's directory,
        or absolute) that files are served from.
    """
    # Drop the 'download_file' prefix and keep only the final path component,
    # so a client cannot escape file_exist_dir with '../' or an absolute path.
    file_name = os.path.basename(cmd[len('download_file'):].strip())

    file_dir = os.path.join(os.path.dirname(__file__), file_exist_dir, file_name)

    if os.path.isfile(file_dir):
        header = {
            'action': 'save_file',
            'filename': file_name,
            # NOTE: the key is misspelled ('szie'), but the receiving side reads
            # exactly this key, so it must stay as-is for compatibility.
            'total_szie': os.path.getsize(file_dir),
            'md5': '',  # placeholder: no checksum is computed or verified
        }
        header_bytes = json.dumps(header).encode('utf-8')
        # sendall() loops until every byte is written; send() may stop short.
        connect_obj.sendall(struct.pack('i', len(header_bytes)) + header_bytes)
        with open(file_dir, 'rb') as f:
            # Fixed-size chunks: iterating a binary file by "lines" can pull a
            # large newline-free file into memory in one piece.
            for chunk in iter(lambda: f.read(8192), b''):
                connect_obj.sendall(chunk)
        result = connect_obj.recv(1024).decode('utf-8')
    else:
        header_bytes = json.dumps('file not found...').encode('utf-8')
        connect_obj.sendall(struct.pack('i', len(header_bytes)) + header_bytes)
        result = f'{file_dir} not found...'
    print(result)

def recv_file(connect_obj, file_save_dir='server_file'):  # customizable file directory next to this script
    """Receive one file sent by the peer and save it.

    Reads a 4-byte length (``struct`` format ``'i'``) and a UTF-8 JSON
    header, then exactly ``header['total_szie']`` bytes of file data, which
    are written to ``file_save_dir`` (resolved against this script's
    directory, as ``send_file`` does), and replies ``b'done'``. If the header
    is the JSON string ``'file not found...'``, it is printed and nothing is
    saved or sent.

    :param connect_obj: connected socket to the peer.
    :param file_save_dir: directory (relative to this script's directory,
        or absolute) that received files are written into.
    :raises ConnectionResetError: if the peer closes the connection before
        the whole header or file has arrived.
    """
    def _recv_exactly(size):
        # recv() may return fewer bytes than requested; keep reading until
        # `size` bytes arrive, and treat b'' (peer closed) as an error.
        buf = bytearray()
        while len(buf) < size:
            chunk = connect_obj.recv(size - len(buf))
            if not chunk:
                raise ConnectionResetError('peer closed the connection during file transfer')
            buf += chunk
        return bytes(buf)

    header_size = struct.unpack('i', _recv_exactly(4))[0]
    header = json.loads(_recv_exactly(header_size).decode('utf-8'))
    if header == 'file not found...':
        print(header)
        return

    # Keep only the final path component so the sender cannot write outside
    # file_save_dir (e.g. via '../' in the file name).
    file_name = os.path.basename(header['filename'])
    file_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), file_save_dir, file_name)
    total_size = header['total_szie']  # key spelling must match the sender
    recv_size = 0
    with open(file_dir, 'wb') as f:
        while recv_size < total_size:
            # Never request more than what remains, so bytes that belong to a
            # following message stay in the socket.
            flow = connect_obj.recv(min(1024, total_size - recv_size))
            if not flow:
                # Without this check a closed peer makes recv() return b''
                # forever and the loop never ends.
                raise ConnectionResetError('peer closed the connection during file transfer')
            f.write(flow)
            recv_size += len(flow)
    connect_obj.sendall(b'done')
    print(f'{file_dir} saved')


class MyRequestHandle(socketserver.BaseRequestHandler):
    """Handle one client connection of the remote-command / file-transfer server.

    The client's first message must equal ``credential``; otherwise a
    rejection message is sent and the connection is closed. After
    authentication each received message is treated as a command:

    * empty (client closed) or ``'exit'`` -- end the session;
    * ``'download_file <name>'`` -- send a file via ``send_file``;
    * ``'upload_file <name>'`` -- receive a file via ``recv_file``;
    * anything else -- run it in a shell and reply with its output.

    Every reply this class sends itself is framed as a 4-byte length
    (``struct`` format ``'i'``, native byte order) followed by the payload.
    """

    # Credential the client must send as its first message (customize here).
    # NOTE(review): it travels as plain text over an unencrypted TCP
    # connection, and any authenticated client can run arbitrary shell
    # commands -- only expose this server on a trusted network.
    credential = 'credit'

    def handle(self):
        """Authenticate the client, then serve its commands until it leaves."""
        print(self.client_address)

        self.authenticated = self._authenticate()

        while self.authenticated:
            try:
                cmd = self.request.recv(1024).decode('utf-8')
                if len(cmd) == 0 or cmd == 'exit':
                    print(f'{self.client_address} 已断开连接')
                    break
                elif cmd.startswith('download_file'):
                    send_file(cmd, self.request)
                elif cmd.startswith('upload_file'):
                    recv_file(self.request)
                else:
                    self._run_shell(cmd)
            except ConnectionResetError:
                break
        self.request.close()

    def _authenticate(self):
        """Read the credential, send the greeting or rejection, return whether it matched."""
        cmd = self.request.recv(24).decode('utf-8')
        if cmd == self.credential:
            self._send_framed(b'waiting for order...')
            return True
        # The length prefix is derived from the payload itself, so the client
        # receives exactly the announced number of bytes.
        self._send_framed(b'password required & connection closing...')
        return False

    def _run_shell(self, cmd):
        """Run `cmd` in a shell and reply with its stdout followed by its stderr."""
        # Intentional remote command execution for authenticated clients.
        # communicate() drains both pipes concurrently (reading all of stdout
        # before stderr can deadlock once the child fills the stderr pipe)
        # and waits for the child so it is not left as a zombie.
        proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout_res, stderr_res = proc.communicate()
        self._send_framed(stdout_res + stderr_res)

    def _send_framed(self, payload):
        """Send `payload` prefixed with its 4-byte length; sendall() writes every byte."""
        self.request.sendall(struct.pack('i', len(payload)) + payload)


# Listening address of the server; customize the IP and PORT here.
SERVER_ADDRESS = ('127.0.0.1', 8000)
s = socketserver.ThreadingTCPServer(SERVER_ADDRESS, MyRequestHandle)
# Blocks forever, spawning one thread per incoming connection.
s.serve_forever()

客户端

import struct
import json
import os
from socket import *

def send_file(cmd, connect_obj, file_exist_dir='client_file'):  # customizable file directory next to this script
    """Upload the local file named in an ``upload_file <name>`` command.

    Wire format: a 4-byte length (``struct`` format ``'i'``, native byte
    order), then a UTF-8 JSON header, then the raw file bytes. Afterwards the
    receiver's acknowledgement (up to 1024 bytes) is read and printed. If the
    local file does not exist, the JSON string ``'file not found...'`` is sent
    as the header instead and no acknowledgement is awaited.

    :param cmd: full command text, e.g. ``'upload_file a.txt'``.
    :param connect_obj: connected socket to the server.
    :param file_exist_dir: directory (relative to this script's directory,
        or absolute) that local files are read from.
    """
    # Drop the 'upload_file' prefix; the rest is the local file name.
    file_name = cmd[len('upload_file'):].strip()

    file_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), file_exist_dir, file_name)

    if os.path.isfile(file_dir):
        header = {
            'action': 'save_file',
            # Only the bare name is sent: the receiver joins it onto its own
            # save directory, where the sender's sub-directories may not exist.
            'filename': os.path.basename(file_name),
            # NOTE: misspelled key, but the receiver reads exactly this key.
            'total_szie': os.path.getsize(file_dir),
            'md5': '',  # placeholder: no checksum is computed or verified
        }
        header_bytes = json.dumps(header).encode('utf-8')
        # sendall() loops until every byte is written; send() may stop short.
        connect_obj.sendall(struct.pack('i', len(header_bytes)) + header_bytes)
        with open(file_dir, 'rb') as f:
            # Fixed-size chunks instead of "lines" of a binary file.
            for chunk in iter(lambda: f.read(8192), b''):
                connect_obj.sendall(chunk)
        result = connect_obj.recv(1024).decode('utf-8')
    else:
        header_bytes = json.dumps('file not found...').encode('utf-8')
        connect_obj.sendall(struct.pack('i', len(header_bytes)) + header_bytes)
        result = f'{file_dir} not found...'

    print(result)


def recv_file(connect_obj, file_save_dir='client_file'):  # customizable file directory next to this script
    """Receive one file sent by the server and save it.

    Reads a 4-byte length (``struct`` format ``'i'``) and a UTF-8 JSON
    header, then exactly ``header['total_szie']`` bytes of file data, which
    are written to ``file_save_dir`` (resolved against this script's
    directory), and replies ``b'done'``. If the header is the JSON string
    ``'file not found...'``, it is printed and nothing is saved or sent.

    :param connect_obj: connected socket to the server.
    :param file_save_dir: directory (relative to this script's directory,
        or absolute) that received files are written into.
    :raises ConnectionResetError: if the server closes the connection before
        the whole header or file has arrived.
    """
    def _recv_exactly(size):
        # recv() may return fewer bytes than requested; keep reading until
        # `size` bytes arrive, and treat b'' (peer closed) as an error.
        buf = bytearray()
        while len(buf) < size:
            chunk = connect_obj.recv(size - len(buf))
            if not chunk:
                raise ConnectionResetError('peer closed the connection during file transfer')
            buf += chunk
        return bytes(buf)

    header_size = struct.unpack('i', _recv_exactly(4))[0]
    header = json.loads(_recv_exactly(header_size).decode('utf-8'))
    if header == 'file not found...':
        print(header)
        return

    # Keep only the final path component so the sender cannot write outside
    # file_save_dir (e.g. via '../' in the file name).
    file_name = os.path.basename(header['filename'])
    file_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), file_save_dir, file_name)
    total_size = header['total_szie']  # key spelling must match the sender
    recv_size = 0
    with open(file_dir, 'wb') as f:
        while recv_size < total_size:
            # Never request more than what remains, so bytes that belong to a
            # following message stay in the socket.
            flow = connect_obj.recv(min(1024, total_size - recv_size))
            if not flow:
                # Without this check a closed peer makes recv() return b''
                # forever and the loop never ends.
                raise ConnectionResetError('peer closed the connection during file transfer')
            f.write(flow)
            recv_size += len(flow)
    connect_obj.sendall(b'done')
    print(f'{file_dir} saved')

def _recv_exactly(sock, size):
    """Read exactly `size` bytes from `sock`.

    :raises ConnectionError: if the server closes the connection first.
    """
    buf = bytearray()
    while len(buf) < size:
        chunk = sock.recv(size - len(buf))
        if not chunk:
            raise ConnectionError('connection closed by server')
        buf += chunk
    return bytes(buf)


client = socket(AF_INET, SOCK_STREAM)

client.connect(('127.0.0.1', 8000))  # customize the server IP and PORT here

# Interactive loop: read a command, send it, and handle the server's reply.
# The socket is always closed on the way out (exit, rejection, lost
# connection, or Ctrl-C).
try:
    while True:
        msg = input('请输入要发送的命令:').strip()
        if not msg:
            continue

        # sendall() loops until every byte is written; send() may stop short.
        client.sendall(msg.encode('utf-8'))

        if msg == 'exit':
            break
        elif msg.startswith('download_file'):
            recv_file(client)
        elif msg.startswith('upload_file'):
            send_file(msg, client)
        else:
            # Length-prefixed reply: 4-byte size (struct 'i'), then the output.
            cmd_res_size = struct.unpack('i', _recv_exactly(client, 4))[0]
            chunks = []
            received = 0
            while received < cmd_res_size:
                # recv(1024) rather than exactly the remaining count: a server
                # may announce a shorter length for its rejection message than
                # it actually sends, and the full text is needed for the check
                # below.
                chunk = client.recv(1024)
                if not chunk:
                    # b'' means the server closed the connection; without this
                    # check the loop would spin forever.
                    raise ConnectionError('connection closed by server')
                chunks.append(chunk)
                received += len(chunk)
            # errors='replace': shell output is not guaranteed to be UTF-8
            # (e.g. GBK on Chinese Windows) and must not crash the client.
            cmd_res = b''.join(chunks).decode('utf-8', errors='replace')
            print(cmd_res)
            if cmd_res == 'password required & connection closing...':
                break
except ConnectionError as exc:
    print(f'与服务端的连接已断开: {exc}')
finally:
    client.close()