Day 32 作业(实现多线程并发FTP)
服务端:
# server
import pickle
import socket
import struct
import traceback
from threading import Thread

import interface
# Listening socket, bound when this module is imported.
server = socket.socket()
server.bind(('127.0.0.1', 8080))
server.listen(3)

# Dispatch table: the request's 'type' field -> handler(recv_dic, conn).
func_dic = dict(
    register=interface.register_interface,
    login=interface.login_interface,
    upload=interface.upload_interface,
    check_file=interface.check_file_interface,
    download=interface.download_interface,
)
def _read_exact(conn, size):
    """Return exactly *size* bytes read from *conn*.

    A single recv() may return fewer bytes than requested, so keep reading.
    Raises ConnectionError if the peer closes the connection first.
    """
    buf = bytearray()
    while len(buf) < size:
        chunk = conn.recv(size - len(buf))
        if not chunk:
            raise ConnectionError('connection closed by peer')
        buf += chunk
    return bytes(buf)


def choose_interface(conn, addr):
    """Serve one client connection until it disconnects, then close it.

    Each request is a 4-byte struct 'i' length header followed by a pickled
    dict whose 'type' key selects a handler in func_dic; the handler is called
    as handler(recv_dic, conn) and writes its own reply.
    """
    try:
        while True:
            header = _read_exact(conn, 4)
            data_len = struct.unpack('i', header)[0]
            # NOTE(security): pickle.loads on bytes received from the network
            # can execute arbitrary code; only acceptable with trusted clients.
            recv_dic = pickle.loads(_read_exact(conn, data_len))
            type1 = recv_dic['type']
            print(type1)
            handler = func_dic.get(type1)
            if handler is None:
                # Always answer, otherwise the client blocks waiting for a reply.
                reply = pickle.dumps({'flag': False, 'msg': f'未知请求类型: {type1}'})
                conn.sendall(struct.pack('i', len(reply)) + reply)
                continue
            handler(recv_dic, conn)
    except OSError:
        # Normal disconnect (ConnectionError from _read_exact) or socket error.
        pass
    except Exception:
        # Malformed request or a bug in a handler: show it instead of
        # silently reporting it as an ordinary disconnect.
        traceback.print_exc()
    finally:
        print(f'{addr}断开连接...')
        conn.close()
if __name__ == '__main__':
    # Serve each client in its own thread so one slow client does not block
    # the others.
    try:
        while True:
            conn, addr = server.accept()
            print(f'{addr}连接中...')
            # daemon=True lets Ctrl+C stop the process without waiting for
            # every connected client to disconnect first.
            t = Thread(target=choose_interface, args=(conn, addr), daemon=True)
            t.start()
    except KeyboardInterrupt:
        pass
    finally:
        # Release the listening port on shutdown.
        server.close()
# interface
from db import models
import common
import os
from db.db_handler import DB_PATH
def register_interface(recv_dic, conn):
    """Create a user account from recv_dic['username'] and recv_dic['password'].

    Sends exactly one reply dict with 'flag' and 'msg' on *conn*.
    """
    username = recv_dic['username']
    # Usernames become directory and file names under DB_PATH/USERS, so only
    # accept a single ordinary path component (no separators, '.', '..' or '').
    if not username or username in ('.', '..') or os.path.basename(username) != username:
        common.send_back({'flag': False, 'msg': '用户名不合法!'}, conn)
        return
    if models.Users.select(username):
        common.send_back({
            'flag': False, 'msg': f'用户{username}已存在!'
        }, conn)
        # Stop here: falling through would overwrite the existing account and
        # send a second reply that desynchronises the client's next request.
        return
    # Constructing a Users object persists it (Users.__init__ calls save()).
    models.Users(username, recv_dic['password'])
    common.send_back({"flag": True, "msg": "注册成功!"}, conn)
def login_interface(recv_dic, conn):
    """Check recv_dic['username'] / recv_dic['password'] against the stored user.

    Sends exactly one reply dict with 'flag' and 'msg' on *conn*.
    NOTE(review): passwords are stored and compared in plain text.
    """
    user_obj = models.Users.select(recv_dic['username'])
    if not user_obj:
        common.send_back({
            'flag': False, 'msg': '用户不存在!'
        }, conn)
        # Without this return the code below dereferences None.pwd.
        return
    if user_obj.pwd == recv_dic['password']:
        back_dic = {"flag": True, "msg": "登录成功!"}
    else:
        back_dic = {"flag": False, "msg": "密码错误!"}
    common.send_back(back_dic, conn)
def upload_interface(recv_dic, conn):
    """Receive an uploaded file and store it in the user's directory.

    recv_dic carries 'username', 'file_name' and 'file_size'; exactly
    'file_size' raw bytes follow the request on *conn*. After they have all
    been read, one reply dict with 'flag' and 'msg' is sent.

    Raises:
        ConnectionError: the client disconnected before sending every byte.
    """
    username = recv_dic['username']
    # Keep only the final path component so a name containing '../' cannot
    # write outside the user's directory.
    file_name = os.path.basename(recv_dic['file_name'])
    file_size = recv_dic['file_size']
    file_path = os.path.join(DB_PATH, 'USERS', username, file_name)
    print(file_path)
    # The account record is pickled into the same directory under the
    # username itself; an upload must never overwrite it.
    rejected = not file_name or file_name == username
    received = 0
    # A rejected upload is still drained (into os.devnull) so the bytes the
    # client already sent are not parsed as the next request.
    with open(os.devnull if rejected else file_path, 'wb') as f:
        while received < file_size:
            # Never read past the announced size.
            data = conn.recv(min(1024, file_size - received))
            if not data:
                raise ConnectionError('client disconnected during upload')
            f.write(data)
            received += len(data)
    if rejected:
        common.send_back({'flag': False, 'msg': '文件名不合法!'}, conn)
    else:
        common.send_back({'flag': True, 'msg': f'文件{file_name}上传成功!'}, conn)
def check_file_interface(recv_dic, conn):
    """Reply with the names of the files the user has uploaded.

    Always sends exactly one reply: {'flag': True, 'msg': [names...]} or
    {'flag': False, 'msg': <reason>}.
    """
    username = recv_dic['username']
    user_dir = os.path.join(DB_PATH, 'USERS', username)
    if not os.path.isdir(user_dir):
        back_dic = {'flag': False, 'msg': '没有用户!'}
    else:
        # Skip the account record (stored in this directory under the username
        # itself) and any sub-directories: only uploaded files are listed.
        file_list = sorted(
            name for name in os.listdir(user_dir)
            if name != username and os.path.isfile(os.path.join(user_dir, name))
        )
        if file_list:
            back_dic = {'flag': True, 'msg': file_list}
        else:
            back_dic = {'flag': False, 'msg': '没有文件!'}
    common.send_back(back_dic, conn)
def download_interface(recv_dic, conn):
    """Send one of the user's uploaded files to the client.

    Replies first with {'flag', 'file_name', 'file_size'} and then streams
    exactly file_size raw bytes. If the file does not exist (or names the
    account record), replies with flag False, a 'msg' and file_size 0 and
    sends no data, so the client is never left waiting.

    Raises:
        EOFError: the file shrank while being sent.
    """
    username = recv_dic['username']
    # Drop directory parts so the client cannot reach files outside its folder.
    file_name = os.path.basename(recv_dic['file_name'])
    file_path = os.path.join(DB_PATH, 'USERS', username, file_name)
    if not file_name or file_name == username or not os.path.isfile(file_path):
        common.send_back({
            'flag': False,
            'msg': '文件不存在!',
            'file_name': file_name,
            'file_size': 0
        }, conn)
        return
    file_size = os.path.getsize(file_path)
    common.send_back({
        'flag': True,
        'file_name': file_name,
        'file_size': file_size
    }, conn)
    sent = 0
    with open(file_path, 'rb') as f:
        while sent < file_size:
            data = f.read(min(1024, file_size - sent))
            if not data:
                # Dropping the connection is better than leaving the client
                # waiting forever for bytes that will never come.
                raise EOFError(f'{file_path} shrank during download')
            # sendall: send() may transmit only part of the buffer.
            conn.sendall(data)
            sent += len(data)
# db_handler
import os
import pickle
# Project root: two directory levels above this file (split()[0] is the head).
BASE_PATH = os.path.split(os.path.split(__file__)[0])[0]
# All persisted records live under <project root>/db.
DB_PATH = os.path.join(BASE_PATH, 'db')
def save_obj(obj):
    """Pickle *obj* to DB_PATH/<CLASSNAME>/<obj.name>/<obj.name>.

    Directories are created as needed. An existing record with the same name
    is overwritten.
    """
    user_dir = os.path.join(DB_PATH, obj.__class__.__name__.upper(), obj.name)
    # makedirs(exist_ok=True) creates missing parents and avoids the
    # check-then-mkdir race between concurrent server threads.
    os.makedirs(user_dir, exist_ok=True)
    user_path = os.path.join(user_dir, obj.name)
    with open(user_path, 'wb') as f:
        pickle.dump(obj, f)
def select_obj(cls, name):
    """Load the pickled *cls* record called *name*.

    Returns None when the class directory, the record directory or the
    record file is missing.
    """
    table_dir = os.path.join(DB_PATH, cls.__name__.upper())
    if not os.path.isdir(table_dir):
        return None
    record_dir = os.path.join(table_dir, name)
    if not os.path.isdir(record_dir):
        return None
    record_path = os.path.join(record_dir, name)
    if not os.path.exists(record_path):
        return None
    with open(record_path, 'rb') as fh:
        return pickle.load(fh)
# models
from db import db_handler
class Users:
    """A registered account, persisted through db_handler."""

    def __init__(self, username, pwd):
        # 'name' and 'pwd' are the attributes stored in the pickled record.
        self.name, self.pwd = username, pwd
        # Creating a Users instance writes it to storage immediately.
        self.save()

    def save(self):
        """Persist this user with db_handler.save_obj."""
        db_handler.save_obj(self)

    @classmethod
    def select(cls, name):
        """Return the stored user called *name* via db_handler.select_obj."""
        return db_handler.select_obj(cls, name)
客户端:
# start
import src
def _main():
    """Entry point: run the interactive client menu."""
    src.run()


if __name__ == '__main__':
    _main()
# client1
import socket
import json
import struct
def get_client(host='127.0.0.1', port=8080):
    """Open a TCP connection to the FTP server and return the connected socket.

    The defaults match the address the server binds to; pass *host*/*port*
    to reach a server elsewhere.
    """
    client = socket.socket()
    client.connect((host, port))
    return client
# src
import common
import client1
import os
from db.db_handler import BASE_PATH
# Session state: name of the currently logged-in user, or None.
user_info = dict(user=None)
def register(client):
    """Prompt for a username and password and register them on the server.

    Repeats until the server accepts the registration. The server's message
    is shown either way, so the user learns why a registration was refused.
    """
    while True:
        username = input('请输入用户名:').strip()
        password = input('请输入密码:').strip()
        re_password = input('请确认密码:').strip()
        if password != re_password:
            print('密码不一致!')
            continue
        send_dic = {
            'type': 'register',
            'username': username,
            'password': password
        }
        recv_dic = common.send_m(send_dic, client)
        print(recv_dic['msg'])
        if recv_dic['flag']:
            break
def login(client):
    """Prompt for credentials until the server accepts them, then open the user menu."""
    while True:
        credentials = {
            'type': 'login',
            'username': input('请输入用户名:').strip(),
            'password': input('请输入密码:').strip(),
        }
        reply = common.send_m(credentials, client)
        if not reply['flag']:
            print(reply['msg'])
            continue
        # Remember who is logged in; later requests read it from user_info.
        user_info['user'] = credentials['username']
        print(reply['msg'])
        user_view(client)
        return
# Top-level menu: the number typed by the user -> handler(client).
func_dic = dict([
    ('1', register),
    ('2', login),
])
def run():
    """Connect to the server and loop over the register/login menu until 'q'.

    The one connection opened here is shared by every later request. It is
    closed when the menu exits, even if an exception propagates.
    """
    client_0 = client1.get_client()
    try:
        while True:
            print('\n1 注册\n2 登录\nq 退出\n')
            choice = input('请选择功能编号:').strip()
            if choice == 'q':
                break
            if choice not in func_dic:
                print('请输入正确编号!')
                continue
            func_dic[choice](client_0)
    finally:
        client_0.close()
def upload(client):
    """Let the user pick a local file from BASE_PATH/upload_file and upload it.

    Re-prompts on invalid input or when the server rejects the upload.
    Returns immediately when there is nothing to upload.
    """
    file_dir = os.path.join(BASE_PATH, 'upload_file')
    while True:
        # Only regular files can be uploaded; open() would fail on a directory.
        user_file_list = sorted(
            name for name in os.listdir(file_dir)
            if os.path.isfile(os.path.join(file_dir, name))
        )
        if not user_file_list:
            # Otherwise no index is valid and the prompt could never be left.
            print('没有文件!')
            return
        for index, name in enumerate(user_file_list):
            print(index, name)
        choice = input('请输入文件编号:').strip()
        if not choice.isdigit():
            print('请输入数字!')
            continue
        choice = int(choice)
        if choice not in range(len(user_file_list)):
            print('请输入正确!')
            continue
        file_name = user_file_list[choice]
        file_path = os.path.join(file_dir, file_name)
        send_dic = {
            'type': 'upload',
            'username': user_info.get('user'),
            'file_name': file_name,
            'file_size': os.path.getsize(file_path)
        }
        recv_dic = common.send_m_file(send_dic, client, file_path)
        print(recv_dic['msg'])
        if recv_dic['flag']:
            break
def download(client):
    """Let the logged-in user pick one of their server-side files and download it.

    The file is saved into BASE_PATH/upload_file. Returns to the caller when
    the server reports an error (no user, no files, or missing file).

    Raises:
        ConnectionError: the server closed the connection mid-download.
    """
    while True:
        send_dic = {
            'type': 'check_file',
            'username': user_info.get('user')
        }
        recv_dic = common.send_m(send_dic, client)
        if not recv_dic['flag']:
            # 'msg' is an error string here, not a file list.
            print(recv_dic['msg'])
            return
        file_list = recv_dic['msg']
        for index, name in enumerate(file_list):
            print(index, name)
        choice = input('请输入文件编号:').strip()
        if not choice.isdigit():
            print('请输入数字!')
            continue
        choice = int(choice)
        if choice not in range(len(file_list)):
            print('请输入正确编号!')
            continue
        file_name = file_list[choice]
        send_dic1 = {
            'type': 'download',
            'file_name': file_name,
            'username': user_info.get('user')
        }
        recv_dic = common.send_m(send_dic1, client)
        # A reply that carries only file_name/file_size has no 'flag'; treat
        # its absence as success.
        if not recv_dic.get('flag', True):
            print(recv_dic.get('msg'))
            return
        file_size = recv_dic['file_size']
        # Never use a server-supplied name as a path.
        file_path = os.path.join(BASE_PATH, 'upload_file', os.path.basename(file_name))
        received = 0
        with open(file_path, 'wb') as f:
            while received < file_size:
                data = client.recv(min(1024, file_size - received))
                if not data:
                    raise ConnectionError('server closed the connection during download')
                f.write(data)
                received += len(data)
        print(f'{file_name}下载完成!')
        break
# Logged-in menu: the number typed by the user -> handler(client).
func_dic1 = dict([
    ('1', upload),
    ('2', download),
])
def user_view(client):
    """Logged-in menu: upload, download, or log out with 'q'.

    *client* must be the socket that is already connected (the one used for
    login); do not create a new connection here. Presumably this keeps every
    request on a single server connection.
    """
    menu = '\n1 上传\n2 下载\nq 注销\n'
    while True:
        print(menu)
        choice = input('请输入功能编码:').strip()
        if choice == 'q':
            # Log out: forget the current user and return to the caller.
            user_info['user'] = None
            return
        handler = func_dic1.get(choice)
        if handler is None:
            print('请输入正确编号!')
            continue
        handler(client)
# common
# 服务端也要有
import struct
import pickle
import os
def _recv_exact(sock, size):
    """Read exactly *size* bytes from *sock*; raise ConnectionError if it closes early."""
    buf = bytearray()
    while len(buf) < size:
        chunk = sock.recv(size - len(buf))
        if not chunk:
            raise ConnectionError('connection closed by peer')
        buf += chunk
    return bytes(buf)


def _send_dict(sock, dic):
    """Send *dic* as a 4-byte struct 'i' length header followed by its pickle."""
    bytes_data = pickle.dumps(dic)
    header = struct.pack('i', len(bytes_data))
    # sendall: a plain send() may transmit only part of the buffer.
    sock.sendall(header + bytes_data)


def _recv_dict(sock):
    """Receive one message framed by _send_dict and return the unpickled dict."""
    header = _recv_exact(sock, 4)
    recv_len = struct.unpack('i', header)[0]
    # NOTE(security): pickle.loads on network data can execute arbitrary
    # code; acceptable only between trusted peers.
    return pickle.loads(_recv_exact(sock, recv_len))


def send_m(send_dic, client):
    """Send a request dict and return the peer's reply dict."""
    _send_dict(client, send_dic)
    return _recv_dict(client)


def send_m_file(send_dic, client, file):
    """Send a request dict, then send_dic['file_size'] bytes of *file*, and return the reply.

    Raises:
        EOFError: *file* holds fewer bytes than send_dic['file_size'].
    """
    _send_dict(client, send_dic)
    file_size = send_dic['file_size']
    sent = 0
    with open(file, 'rb') as f:
        while sent < file_size:
            # Never send more than announced, even if the file has grown.
            data = f.read(min(1024, file_size - sent))
            if not data:
                raise EOFError(f'{file} ended after {sent} of {file_size} bytes')
            client.sendall(data)
            sent += len(data)
    return _recv_dict(client)


def send_back(back_dic, conn):
    """Send a reply dict (server side); no answer is awaited."""
    _send_dict(conn, back_dic)