Tiny_Lu
不忘初心

Day 32 作业(实现多线程并发FTP)

服务端:

# server
import socket
import pickle
import struct
import interface

# Listening TCP socket used by the accept loop at the bottom of this module.
# NOTE: binding happens at import time.
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(('127.0.0.1', 8080))
server.listen(3)  # backlog: pending connections queued before accept()

# Maps a request's 'type' field to the interface handler that serves it;
# every handler is invoked as handler(recv_dic, conn).
func_dic = dict(
    register=interface.register_interface,
    login=interface.login_interface,
    upload=interface.upload_interface,
    check_file=interface.check_file_interface,
    download=interface.download_interface,
)


def choose_interface(conn, addr):
    """Serve one client connection until it disconnects.

    Each request is a 4-byte ``struct`` header (native ``'i'``) holding the
    payload length, followed by a pickled dict whose ``'type'`` key selects
    a handler from ``func_dic``. The handler is called as
    ``handler(recv_dic, conn)`` and is responsible for replying. Unknown
    types get a ``{'flag': False, ...}`` reply so the client does not block
    forever. The connection is always closed when this returns.

    Args:
        conn: the accepted client socket.
        addr: the client address, used only for log messages.
    """
    try:
        while True:
            data_len = struct.unpack('i', _recv_exact(conn, 4))[0]
            # NOTE(security): pickle.loads on network data can execute
            # arbitrary code; only acceptable with trusted clients.
            recv_dic = pickle.loads(_recv_exact(conn, data_len))
            type1 = recv_dic['type']
            print(type1)

            handler = func_dic.get(type1)
            if handler is None:
                payload = pickle.dumps({'flag': False, 'msg': f'未知请求类型{type1}!'})
                conn.sendall(struct.pack('i', len(payload)) + payload)
                continue
            handler(recv_dic, conn)
    except ConnectionError:
        pass  # peer closed or reset the connection: normal end of session
    except Exception as e:  # per-connection boundary: one bad client must not kill the server
        print(f'{addr}出错: {e!r}')
    finally:
        print(f'{addr}断开连接...')
        conn.close()


def _recv_exact(conn, size):
    """Return exactly *size* bytes from *conn*.

    ``socket.recv`` may return fewer bytes than requested, so this loops.
    Chunks are capped at 64 KiB so a bogus length does not allocate a huge
    buffer up front.

    Raises:
        ConnectionError: the peer closed the connection first.
    """
    buf = bytearray()
    while len(buf) < size:
        chunk = conn.recv(min(size - len(buf), 65536))
        if not chunk:
            raise ConnectionError('peer closed the connection')
        buf += chunk
    return bytes(buf)


if __name__ == '__main__':
    # Imported here because the module-level imports do not include threading.
    from threading import Thread

    while True:
        conn, addr = server.accept()
        print(f'{addr}连接中...')
        # One thread per client so several users can be served concurrently.
        t = Thread(target=choose_interface, args=(conn, addr))
        t.start()
# interface
from db import models
import common
import os
from db.db_handler import DB_PATH


def register_interface(recv_dic, conn):
    """Create an account from ``recv_dic['username']`` / ``recv_dic['password']``.

    Sends exactly one reply via ``common.send_back``. The reply is a failure
    when the name is unusable or already taken, and a success after the
    account has been created.
    """
    username = recv_dic['username']
    # The username becomes a directory and file name under DB_PATH/USERS,
    # so reject empty names, '.', '..' and anything containing a separator.
    if (not isinstance(username, str) or username in ('', '.', '..')
            or os.path.basename(username) != username):
        common.send_back({'flag': False, 'msg': '用户名不合法!'}, conn)
        return

    if models.Users.select(username):
        common.send_back({
            'flag': False, 'msg': f'用户{username}已存在!'
        }, conn)
        return  # must not fall through: that would overwrite the user and reply twice

    models.Users(username, recv_dic['password'])
    common.send_back({"flag": True, "msg": "注册成功!"}, conn)


def login_interface(recv_dic, conn):
    """Check ``recv_dic['username']`` / ``recv_dic['password']`` against the stored account.

    Sends exactly one reply via ``common.send_back``. The reply says the user
    does not exist, the password is wrong, or the login succeeded.
    """
    user_obj = models.Users.select(recv_dic['username'])
    if not user_obj:
        common.send_back({
            'flag': False, 'msg': '用户不存在!'
        }, conn)
        return  # without this, user_obj.pwd below raises on None

    # NOTE(review): plain-text comparison; Users stores the password as given.
    if user_obj.pwd == recv_dic['password']:
        back_dic = {"flag": True, "msg": "登录成功!"}
    else:
        back_dic = {"flag": False, "msg": "密码错误!"}
    common.send_back(back_dic, conn)


def upload_interface(recv_dic, conn):
    """Receive an uploaded file into ``DB_PATH/USERS/<username>/<file_name>``.

    The client sends a header dict (``username``, ``file_name``,
    ``file_size``) followed by exactly ``file_size`` raw bytes on the same
    connection. This reads exactly that many bytes and then replies once via
    ``common.send_back``.

    A request is refused, after its bytes are drained so the stream stays in
    sync, when any of these hold:

    - either name is unsafe as a path component;
    - the user folder does not exist;
    - the file would overwrite the account record. ``db_handler.save_obj``
      pickles that record to ``USERS/<name>/<name>``.

    Raises:
        ConnectionError: the client disconnected before all bytes arrived.
            The partially written file is removed.
    """
    username = recv_dic['username']
    file_name = recv_dic['file_name']
    file_size = recv_dic['file_size']

    reason = None
    if _is_unsafe_name(username) or _is_unsafe_name(file_name):
        reason = f'文件名{file_name}不合法!'
    elif file_name == username:
        reason = f'文件名{file_name}不可用!'
    else:
        user_dir = os.path.join(DB_PATH, 'USERS', username)
        if not os.path.isdir(user_dir):
            reason = f'用户{username}不存在!'

    if reason is not None:
        # The file bytes are already on the wire; consume them so the next
        # request header is read from the right offset.
        _recv_exactly(conn, file_size, lambda chunk: None)
        common.send_back({'flag': False, 'msg': reason}, conn)
        return

    file_path = os.path.join(user_dir, file_name)
    print(file_path)
    try:
        with open(file_path, 'wb') as f:
            _recv_exactly(conn, file_size, f.write)
    except ConnectionError:
        os.remove(file_path)
        raise

    common.send_back({'flag': True, 'msg': f'文件{file_name}上传成功!'}, conn)


def _is_unsafe_name(name):
    """Return True if *name* cannot be used as a single path component."""
    return (not isinstance(name, str) or name in ('', '.', '..')
            or os.path.basename(name) != name)


def _recv_exactly(conn, size, sink):
    """Read exactly *size* bytes from *conn* in chunks of at most 1024 bytes.

    Each chunk is passed to *sink*. Reads never go past *size*, so bytes of
    the next message are left in the socket.

    Raises:
        ConnectionError: the peer closed the connection early. A bare
            ``recv`` would otherwise return b'' forever.
    """
    remaining = size
    while remaining > 0:
        chunk = conn.recv(min(1024, remaining))
        if not chunk:
            raise ConnectionError('client disconnected during upload')
        sink(chunk)
        remaining -= len(chunk)


def check_file_interface(recv_dic, conn):
    """Reply with the sorted list of files stored for ``recv_dic['username']``.

    Always sends exactly one reply via ``common.send_back``, so the client's
    blocking receive never hangs. The reply is one of:

    - ``{'flag': True, 'msg': [names]}`` when the user has files;
    - ``{'flag': False, 'msg': '没有文件!'}`` when the folder is empty;
    - ``{'flag': False, 'msg': '没有用户!'}`` when the user folder is missing
      or the name is unsafe.

    The account record that ``db_handler.save_obj`` pickles into the same
    folder (``USERS/<name>/<name>``) is not listed.
    """
    username = recv_dic['username']
    valid = (isinstance(username, str) and username not in ('', '.', '..')
             and os.path.basename(username) == username)
    user_dir = os.path.join(DB_PATH, 'USERS', username) if valid else ''

    if not os.path.isdir(user_dir):
        back_dic = {'flag': False, 'msg': '没有用户!'}
    else:
        file_list = sorted(name for name in os.listdir(user_dir) if name != username)
        if file_list:
            back_dic = {'flag': True, 'msg': file_list}
        else:
            back_dic = {'flag': False, 'msg': '没有文件!'}
    common.send_back(back_dic, conn)


def download_interface(recv_dic, conn):
    """Send one stored file from ``DB_PATH/USERS/<username>/`` to the client.

    Protocol: first a header dict via ``common.send_back`` with
    ``file_name`` and ``file_size``, then exactly ``file_size`` raw bytes.
    If the file cannot be served, the header instead carries ``'flag': False``,
    a ``msg`` and ``file_size`` 0, and no bytes follow. Clients that only
    read ``file_size`` therefore do not wait for data.

    Raises:
        EOFError: the file shrank while being sent, so the stream cannot be
            completed.
    """
    username = recv_dic['username']
    # Keep only the final path component so '../other/x' cannot escape the folder.
    file_name = os.path.basename(recv_dic['file_name'])
    safe_user = (isinstance(username, str) and username not in ('', '.', '..')
                 and os.path.basename(username) == username)
    file_path = os.path.join(DB_PATH, 'USERS', username, file_name) if safe_user else ''

    if not os.path.isfile(file_path):
        common.send_back({
            'flag': False,
            'msg': f'文件{file_name}不存在!',
            'file_name': file_name,
            'file_size': 0,
        }, conn)
        return

    file_size = os.path.getsize(file_path)
    common.send_back({
        'flag': True,
        'file_name': file_name,
        'file_size': file_size,
    }, conn)

    sent = 0
    with open(file_path, 'rb') as f:
        while sent < file_size:
            data = f.read(min(1024, file_size - sent))
            if not data:
                raise EOFError(f'{file_path} shrank while being sent')
            conn.sendall(data)  # send() may write only part of the buffer
            sent += len(data)
# db_handler
import os
import pickle

BASE_PATH = os.path.dirname(os.path.dirname(__file__))

DB_PATH = os.path.join(BASE_PATH, 'db')


def save_obj(obj):
    """Pickle *obj* to ``DB_PATH/<CLASSNAME>/<obj.name>/<obj.name>``.

    The top-level folder is the upper-cased class name (``Users`` ->
    ``USERS``). Each object gets a sub-folder named after ``obj.name``, and
    the record file inside it carries the same name. Missing directories are
    created.
    """
    obj_dir = os.path.join(DB_PATH, obj.__class__.__name__.upper(), obj.name)
    # makedirs(exist_ok=True) replaces the isdir()/mkdir() pairs, which could
    # raise FileExistsError when two server threads create a folder at once.
    os.makedirs(obj_dir, exist_ok=True)
    with open(os.path.join(obj_dir, obj.name), 'wb') as f:
        pickle.dump(obj, f)  # leaving the with-block flushes and closes the file


def select_obj(cls, name):
    """Load the record that ``save_obj`` stored for *name*, or return None.

    Looks for ``DB_PATH/<CLS NAME UPPER>/<name>/<name>``. If any directory or
    the record file is missing, the result is None.

    NOTE(security): the file is unpickled, so whoever can write into these
    folders can run code in this process.
    """
    class_dir = os.path.join(DB_PATH, cls.__name__.upper())
    if not os.path.isdir(class_dir):
        return None

    record_dir = os.path.join(class_dir, name)
    if not os.path.isdir(record_dir):
        return None

    record_path = os.path.join(record_dir, name)
    if not os.path.exists(record_path):
        return None

    with open(record_path, 'rb') as fh:
        return pickle.load(fh)
# models
from db import db_handler


class Users:
    """A registered FTP account persisted through ``db_handler``.

    Constructing an instance immediately writes it to storage via ``save``.
    """

    def __init__(self, username, pwd):
        # NOTE(review): the password is kept exactly as given (plain text)
        # and pickled by save().
        self.name, self.pwd = username, pwd
        self.save()

    def save(self):
        """Persist this account with ``db_handler.save_obj``."""
        db_handler.save_obj(self)

    @classmethod
    def select(cls, name):
        """Return whatever ``db_handler.select_obj`` loads for *name* (None if absent)."""
        return db_handler.select_obj(cls, name)

客户端:

# start
import src


if __name__ == '__main__':
    # Script entry point: hand control to the interactive client menu.
    src.run()
# client1
import socket
import json
import struct


def get_client(host='127.0.0.1', port=8080):
    """Open a TCP connection to the FTP server and return the connected socket.

    Args:
        host: server address; defaults to the local server used by this project.
        port: server port.

    Raises:
        OSError: the connection could not be established. The socket is
            closed before the error propagates.
    """
    client = socket.socket()
    try:
        client.connect((host, port))
    except OSError:
        client.close()
        raise
    return client
# src
import common
import client1
import os
from db.db_handler import BASE_PATH

# Session state: name of the currently logged-in user, None when logged out.
user_info = dict(user=None)


def register(client):
    """Prompt for a new account and register it, retrying until the server accepts.

    The password must be typed twice and match. The server's reply message
    is always shown, so a rejection (for example a taken username) is
    explained instead of silently re-prompting.
    """
    while True:
        username = input('请输入用户名:').strip()
        password = input('请输入密码:').strip()
        re_password = input('请确认密码:').strip()

        if password != re_password:
            print('密码不一致!')
            continue

        send_dic = {
            'type': 'register',
            'username': username,
            'password': password
        }
        recv_dic = common.send_m(send_dic, client)
        print(recv_dic['msg'])
        if recv_dic['flag']:
            break


def login(client):
    """Prompt for credentials until the server accepts them, then open the user menu.

    On success the username is recorded in ``user_info`` before
    ``user_view`` is entered with the same socket.
    """
    while True:
        name = input('请输入用户名:').strip()
        secret = input('请输入密码:').strip()

        reply = common.send_m(
            {'type': 'login', 'username': name, 'password': secret},
            client,
        )
        accepted = reply['flag']
        print(reply['msg'])
        if not accepted:
            continue

        user_info['user'] = name
        user_view(client)
        return


# Main-menu dispatch: menu key -> handler taking the shared client socket.
func_dic = {'1': register, '2': login}


def run():
    """Show the top-level menu and dispatch to register/login until the user quits.

    One socket is opened up front, passed to every handler, and closed when
    the menu exits, including when a handler raises.
    """
    menu = '\n        1 注册\n        2 登录\n        q 退出\n        '
    client_0 = client1.get_client()
    try:
        while True:
            print(menu)

            choice = input('请选择功能编号:').strip()

            if choice == 'q':
                break

            if choice not in func_dic:
                print('请输入正确编号!')
                continue

            func_dic[choice](client_0)
    finally:
        client_0.close()


def upload(client):
    """Let the logged-in user pick a local file from ``BASE_PATH/upload_file`` and upload it.

    Only regular files are offered, sorted by name so the numbering is
    stable. The chosen file is streamed with ``common.send_m_file``. The loop
    repeats until the server reports success. Entering ``q`` returns to the
    menu, and a missing or empty folder returns immediately with a message.
    """
    file_dir = os.path.join(BASE_PATH, 'upload_file')
    while True:
        try:
            user_file_list = sorted(
                name for name in os.listdir(file_dir)
                if os.path.isfile(os.path.join(file_dir, name))
            )
        except FileNotFoundError:
            print(f'目录{file_dir}不存在!')
            return
        if not user_file_list:
            print('没有可上传的文件!')
            return

        for index, file in enumerate(user_file_list):
            print(index, file)

        choice = input('请输入文件编号(q 返回):').strip()
        if choice == 'q':
            return
        # isdecimal(), unlike isdigit(), only accepts strings int() can parse.
        if not choice.isdecimal():
            print('请输入数字!')
            continue

        index = int(choice)
        if index >= len(user_file_list):
            print('请输入正确!')
            continue

        file_name = user_file_list[index]
        file_path = os.path.join(file_dir, file_name)

        send_dic = {
            'type': 'upload',
            'username': user_info.get('user'),
            'file_name': file_name,
            'file_size': os.path.getsize(file_path),
        }
        recv_dic = common.send_m_file(send_dic, client, file_path)

        print(recv_dic['msg'])
        if recv_dic['flag']:
            return


def download(client):
    """List the user's files on the server and download the chosen one.

    The file is written to ``BASE_PATH/upload_file``. Entering ``q`` at the
    prompt returns to the menu. If the server reports no files, or refuses
    the download, its message is printed and the function returns.

    Raises:
        ConnectionError: the server closed the connection mid-transfer. The
            partial local file is removed.
    """
    while True:
        recv_dic = common.send_m({
            'type': 'check_file',
            'username': user_info.get('user'),
        }, client)
        if not recv_dic['flag']:
            # On failure 'msg' is an error string, not a file list.
            print(recv_dic['msg'])
            return

        file_list = recv_dic['msg']
        for index, file in enumerate(file_list):
            print(index, file)

        choice = input('请输入文件编号(q 返回):').strip()
        if choice == 'q':
            return
        if not choice.isdecimal():
            print('请输入数字!')
            continue

        index = int(choice)
        if index >= len(file_list):
            print('请输入正确编号!')
            continue

        file_name = file_list[index]
        recv_dic = common.send_m({
            'type': 'download',
            'file_name': file_name,
            'username': user_info.get('user'),
        }, client)
        # A header with 'flag': False is a refusal (no bytes follow);
        # headers without 'flag' are treated as success.
        if not recv_dic.get('flag', True):
            print(recv_dic.get('msg', '下载失败!'))
            return

        file_size = recv_dic['file_size']
        # Use only the last component of the server-supplied name as a local path.
        file_path = os.path.join(BASE_PATH, 'upload_file', os.path.basename(file_name))
        received = 0
        try:
            with open(file_path, 'wb') as f:
                while received < file_size:
                    # Never read past file_size: later bytes belong to the next reply.
                    data = client.recv(min(1024, file_size - received))
                    if not data:
                        raise ConnectionError('server closed the connection during download')
                    f.write(data)
                    received += len(data)
        except ConnectionError:
            os.remove(file_path)
            raise

        print(f'{file_name}下载完成!')
        return


# Logged-in menu dispatch: menu key -> handler taking the shared client socket.
func_dic1 = {'1': upload, '2': download}


def user_view(client):
    """Menu shown after login: upload, download, or 'q' to log out.

    The caller's socket must be passed in and reused; a new connection must
    not be created here. Logging out clears ``user_info['user']``.
    """
    while True:
        print('''
        1 上传
        2 下载
        q 注销
        ''')

        choice = input('请输入功能编码:').strip()
        if choice == 'q':
            break

        action = func_dic1.get(choice)
        if action is None:
            print('请输入正确编号!')
        else:
            action(client)

    user_info['user'] = None
# common
# 服务端也要有
import struct
import pickle
import os


def _send_dic(send_dic, sock):
    """Send *send_dic* as a 4-byte length header followed by its pickle.

    The header uses ``struct`` format ``'i'`` (native byte order and size),
    the same format the receiving side unpacks.
    """
    bytes_data = pickle.dumps(send_dic)
    # sendall() retries until everything is written; send() may stop short.
    sock.sendall(struct.pack('i', len(bytes_data)) + bytes_data)


def _recv_exact(sock, size):
    """Return exactly *size* bytes from *sock*.

    Raises:
        ConnectionError: the peer closed the connection first.
    """
    buf = bytearray()
    while len(buf) < size:
        chunk = sock.recv(min(size - len(buf), 65536))
        if not chunk:
            raise ConnectionError('connection closed before the full message arrived')
        buf += chunk
    return bytes(buf)


def _recv_dic(sock):
    """Receive one length-prefixed pickled reply and return the decoded object."""
    recv_len = struct.unpack('i', _recv_exact(sock, 4))[0]
    # NOTE(security): pickle.loads is only safe with a trusted peer.
    return pickle.loads(_recv_exact(sock, recv_len))


def send_m(send_dic, client):
    """Send one request dict over *client* and return the server's reply dict."""
    _send_dic(send_dic, client)
    return _recv_dic(client)


def send_m_file(send_dic, client, file):
    """Send a request dict, then the raw bytes of *file*, and return the reply dict.

    Exactly ``send_dic['file_size']`` bytes of *file* are streamed after the
    header, in chunks of at most 1024 bytes.

    Raises:
        EOFError: *file* is shorter than the declared ``file_size``. The
            server would otherwise wait forever for the missing bytes.
    """
    _send_dic(send_dic, client)

    file_size = send_dic['file_size']
    sent = 0
    with open(file, 'rb') as f:
        while sent < file_size:
            # Never send more than declared, even if the file grew meanwhile.
            data = f.read(min(1024, file_size - sent))
            if not data:
                raise EOFError(f'{file} is shorter than the declared file_size {file_size}')
            client.sendall(data)
            sent += len(data)

    return _recv_dic(client)


def send_back(back_dic, conn):
    """Send a reply dict on *conn* as a 4-byte ``'i'`` length header plus its pickle.

    Header and payload go out in a single ``sendall`` call. Plain ``send``
    may write only part of the buffer.
    """
    bytes_data = pickle.dumps(back_dic)
    header = struct.pack('i', len(bytes_data))
    conn.sendall(header + bytes_data)
posted @ 2019-10-27 14:39  二二二二白、  阅读(235)  评论(0)  编辑  收藏  举报