泽桐-Merlin

They built goals around their dreams and nerver quit.

  博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

原多线程版FTP程序:http://www.cnblogs.com/linzetong/p/8290378.html

只需要在原来的代码基础上稍作修改:

一、gevent协程版本

1、 导入gevent模块 

import gevent

 

2、python的异步库gevent打猴子补丁,他的用途是让你方便的导入非阻塞的模块,不需要特意的去引入。

 #注意:from gevent import monkey;monkey.patch_all()必须放到被打补丁者的前面,如time,socket模块之前

from gevent import monkey;monkey.patch_all()

 

3、 把socket设置为非阻塞

self.sock.setblocking(0)  

 

4、 修改run函数,

 # gevent 实现单线程多并发

gevent.spawn(TCPHandler.handle, TCPHandler(), self.request, self.cli_addr)

 


其他不用更改

 

二、select IO多路复用版本

1、 导入select模块

import select

 

2、 把socket设置为非阻塞

self.sock.setblocking(0)  

 

3、 修改run函数,用select.select()方法接收并监控多个通信socket列表

def run(self):
        while True:  # 链接循环
            # select 单进程实现同时处理请求
            inputs = [self.sock, ]
            outputs = []
            while True:
                readable, writeable, exceptional = select.select(inputs, outputs, inputs)
                for r in readable:
                    if r is self.sock:
                        request, client_address = self.sock.accept()
                        inputs.append(request)
                    else:
                        print('处理request:%s'%id(r))
                        return_code, request = TCPHandler().handle(r, )
                        if not return_code:
                            request.close()
inputs.remove(request)
print('client[%s] is disconect' % ((request.getpeername()),))

 


4、完整代码:

server.py

  1 # -*- coding: utf-8 -*-
  2 import socket
  3 import os, json, re, struct, threading, time
  4 import gevent
  5 from gevent import monkey
  6 import select
  7 from lib import commons
  8 from conf import settings
  9 from core import logger
 10 
 11 monkey.patch_all()
 12 
 13 
 14 class Server(object):
 15     def __init__(self):
 16         self.init_dir()
 17         self.sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
 18         self.sock.setblocking(0)  # select实现同时处理请求,需要设置为非阻塞
 19         # self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 20         self.sock.bind((settings.server_bind_ip, settings.server_bind_port))
 21         self.sock.listen(settings.server_listen)
 22         print("\033[42;1mserver started sucessful!\033[0m")
 23         self.run()
 24 
 25     @staticmethod
 26     def init_dir():
 27         if not os.path.exists(os.path.join(settings.base_path, 'logs')): os.mkdir(
 28             os.path.join(settings.base_path, 'logs'))
 29         if not os.path.exists(os.path.join(settings.base_path, 'db')): os.mkdir(os.path.join(settings.base_path, 'db'))
 30         if not os.path.exists(os.path.join(settings.base_path, 'home')): os.mkdir(
 31             os.path.join(settings.base_path, 'home'))
 32 
 33     def run(self):
 34         while True:  # 链接循环
 35             # select 单进程实现同时处理请求
 36             inputs = [self.sock, ]
 37             outputs = []
 38             while True:
 39                 readable, writeable, exceptional = select.select(inputs, outputs, inputs)
 40                 for r in readable:
 41                     if r is self.sock:
 42                         request, client_address = self.sock.accept()
 43                         inputs.append(request)
 44                     else:
 45                         print('处理request:%s'%id(r))
 46                         return_code, request = TCPHandler().handle(r, )
 47                         if not return_code:
 48                             request.close()
 49                             print('client[%s] is disconect' % ((request.getpeername()),))
 50                 # self.request, self.cli_addr = self.sock.accept()
 51                 # self.request.settimeout(300)
 52                 # 多线程处理请求
 53                 # thread = threading.Thread(target=TCPHandler.handle, args=(TCPHandler(), self.request, self.cli_addr))
 54                 # thread.start()
 55                 # gevent 实现单线程多并发
 56                 # gevent.spawn(TCPHandler.handle, TCPHandler(), self.request, self.cli_addr)
 57 
 58 
 59 class TCPHandler(object):
 60     STATUS_CODE = {
 61         200: 'Passed authentication!',
 62         201: 'Wrong username or password!',
 63         202: 'Username does not exist!',
 64         300: 'cmd successful , the target path be returned in returnPath',
 65         301: 'cmd format error!',
 66         302: 'The path or file could not be found!',
 67         303: 'The dir is exist',
 68         304: 'The file has been downloaded or the size of the file is exceptions',
 69         305: 'Free space is not enough',
 70         401: 'File MD5 inspection failed',
 71         400: 'File MD5 inspection success',
 72     }
 73 
 74     def __init__(self):
 75         self.server_logger = logger.logger('server')
 76         self.server_logger.debug("server TCPHandler started successful!")
 77 
 78     def handle(self, request, address=(None, None)):
 79         self.request = request
 80         self.cli_addr = request.getpeername()
 81         self.server_logger.info('client[%s] is conecting' % ((request.getpeername()),))
 82         print('client[%s] is conecting' % ((request.getpeername()),))
 83         # while True:  # 通讯循环
 84         try:
 85             # 1、接收客户端的ftp命令
 86             print("waiting receive client[%s] ftp command.." % ((request.getpeername()),), id(self), self)
 87             header_dic, req_dic = self.recv_request()
 88             if not header_dic: return False, request
 89             if not header_dic['cmd']: return False, request
 90             print('receive client ftp command:%s' % header_dic['cmd'])
 91             # 2、解析ftp命令,获取相应命令参数(文件名)
 92             cmds = header_dic['cmd'].split()  # ['register',]、['get', 'a.txt']
 93             if hasattr(self, cmds[0]):
 94                 self.server_logger.info('interface:[%s], request:{client:[%s:%s] action:[%s]}' % (
 95                     cmds[0], self.cli_addr[0], self.cli_addr[1], header_dic['cmd']))
 96                 getattr(self, cmds[0])(header_dic, req_dic)
 97                 return True, request
 98         except (ConnectionResetError, ConnectionAbortedError):
 99             return False, request
100         except socket.timeout:
101             print('time out %s' % ((request.getpeername()),))
102             return False, request
103         # self.request.close()
104         # self.server_logger.info('client %s is disconect' % ((self.cli_addr,)))
105         # print('client[%s:%s] is disconect' % (self.cli_addr[0], self.cli_addr[1]))
106 
107     def unpack_header(self):
108         try:
109             pack_obj = self.request.recv(4)
110             header_size = struct.unpack('i', pack_obj)[0]
111             header_bytes = self.request.recv(header_size)
112             header_json = header_bytes.decode('utf-8')
113             header_dic = json.loads(header_json)
114             return header_dic
115         except struct.error:  # 避免客户端发送错误格式的header_size
116             return
117 
118     def unpack_info(self, info_size):
119         recv_size = 0
120         info_bytes = b''
121         while recv_size < info_size:
122             res = self.request.recv(1024)
123             info_bytes += res
124             recv_size += len(res)
125         info_json = info_bytes.decode('utf-8')
126         info_dic = json.loads(info_json)  # {'username':ton, 'password':123}
127         info_md5 = commons.getStrsMd5(info_bytes)
128         return info_dic, info_md5
129 
130     def recv_request(self):
131         header_dic = self.unpack_header()  # {'cmd':'register','info_size':0}
132         if not header_dic: return None, None
133         req_dic, info_md5 = self.unpack_info(header_dic['info_size'])
134         if header_dic.get('md5'):
135             # 校检请求内容md5一致性
136             if info_md5 == header_dic['md5']:
137                 pass
138             # print('\033[42;1m请求内容md5校检结果一致\033[0m')
139             else:
140                 pass
141                 # print('\033[31;1m请求内容md5校检结果不一致\033[0m')
142         return header_dic, req_dic
143 
144     def response(self, **kwargs):
145         rsp_info = kwargs
146         rsp_bytes = commons.getDictBytes(rsp_info)
147         md5 = commons.getStrsMd5(rsp_bytes)
148         header_size_pack, header_bytes = commons.make_header(info_size=len(rsp_bytes), md5=md5)
149         self.request.sendall(header_size_pack)
150         self.request.sendall(header_bytes)
151         self.request.sendall(rsp_bytes)
152 
153     def register(self, header_dic, req_dic):  # {'cmd':'register','info_size':0,'resultCode':0,'resultDesc':None}
154         username = req_dic['user_info']['username']
155         # 更新数据库,并制作响应信息字典
156         if not os.path.isfile(os.path.join(settings.db_file, '%s.json' % username)):
157             # 更新数据库
158             user_info = dict()
159             user_info['username'] = username
160             user_info['password'] = req_dic['user_info']['password']
161             user_info['home'] = os.path.join(settings.user_home_dir, username)
162             user_info['quota'] = settings.user_quota * (1024 * 1024)
163             commons.save_to_file(user_info, os.path.join(settings.db_file, '%s.json' % username))
164             resultCode = 0
165             resultDesc = None
166             # 创建家目录
167             if not os.path.exists(os.path.join(settings.user_home_dir, username)):
168                 os.mkdir(os.path.join(settings.user_home_dir, username))
169             self.server_logger.info('client[%s:%s] 注册用户[%s]成功' % (self.cli_addr[0], self.cli_addr[1], username))
170         else:
171             resultCode = 1
172             resultDesc = '该用户已存在,注册失败'
173             self.server_logger.warning('client[%s:%s] 注册用户[%s]失败:%s' % (self.cli_addr[0], self.cli_addr[1],
174                                                                         username, resultDesc))
175         # 响应客户端注册请求
176         self.response(resultCode=resultCode, resultDesc=resultDesc)
177 
178     @staticmethod
179     def auth(req_dic):
180         # print(req_dic['user_info'])
181         user_info = None
182         status_code = 201
183         try:
184             req_username = req_dic['user_info']['username']
185             db_file = os.path.join(settings.db_file, '%s.json' % req_username)
186             # 验证用户名密码,并制作响应信息字典
187             if not os.path.isfile(db_file):
188                 status_code = 202
189             else:
190                 with open(db_file, 'r') as f:
191                     user_info_db = json.load(f)
192                 if user_info_db['password'] == req_dic['user_info']['password']:
193                     status_code = 200
194                     user_info = user_info_db
195             return status_code, user_info
196         # 捕获  客户端鉴权请求时发送一个空字典或错误的字典  的异常
197         except KeyError:
198             return 201, user_info
199 
200     def login(self, header_dic, req_dic):
201         # 鉴权
202         status_code, user_info = self.auth(req_dic)
203         # 响应客户端登陆请求
204         self.response(user_info=user_info, resultCode=status_code)
205 
206     def query_quota(self, header_dic, req_dic):
207         used_quota = None
208         total_quota = None
209         # 鉴权
210         status_code, user_info = self.auth(req_dic)
211         # 查询配额
212         if status_code == 200:
213             used_quota = commons.getFileSize(user_info['home'])
214             total_quota = user_info['quota']
215         # 响应客户端配额查询请求
216         self.response(resultCode=status_code, total_quota=total_quota, used_quota=used_quota)
217 
218     @staticmethod
219     def parse_file_path(req_path, cur_path):
220         req_path = req_path.replace(r'/', '\\')
221         req_path = req_path.replace(r'//', r'/', )
222         req_path = req_path.replace('\\\\', '\\')
223         req_path = req_path.replace('~\\', '', 1)
224         req_path = req_path.replace(r'~', '', 1)
225         req_paths = re.findall(r'[^\\]+', req_path)
226         cur_paths = re.findall(r'[^\\]+', cur_path)
227         cur_paths.extend(req_paths)
228         cur_paths[0] += '\\'
229         while '.' in cur_paths:
230             cur_paths.remove('.')
231         while '..' in cur_paths:
232             for index, item in enumerate(cur_paths):
233                 if item == '..':
234                     cur_paths.pop(index)
235                     cur_paths.pop(index - 1)
236                     break
237         return cur_paths
238 
239     def cd(self, header_dic, req_dic):
240         cmds = header_dic['cmd'].split()
241         # 鉴权
242         status_code, user_info = self.auth(req_dic)
243         home = os.path.join(settings.user_home_dir, user_info['username'])
244         # 先定义响应信息
245         returnPath = req_dic['user_info']['cur_path']
246         if status_code == 200:
247             if len(cmds) != 1:
248                 # 解析cd的真实路径
249                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
250                 cd_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
251                 print('cd解析后的路径:', cd_path)
252                 if os.path.isdir(cd_path):
253                     if home in cd_path:
254                         resultCode = 300
255                         returnPath = cd_path.replace('%s\\' % settings.user_home_dir, '', 1)
256                     else:
257                         resultCode = 302
258                 else:
259                     resultCode = 302
260             else:
261                 resultCode = 301
262         else:
263             resultCode = 201
264         # 响应客户端的cd命令结果
265         print('cd发送给客户端的路径:', returnPath)
266         self.response(resultCode=resultCode, returnPath=returnPath)
267 
268     def ls(self, header_dic, req_dic):
269         cmds = header_dic['cmd'].split()
270         # 鉴权
271         status_code, user_info = self.auth(req_dic)
272         home = os.path.join(settings.user_home_dir, user_info['username'])
273         # 先定义响应信息
274         returnFilenames = None
275         if status_code == 200:
276             if len(cmds) <= 2:
277                 # 解析ls的真实路径
278                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
279                 if len(cmds) == 2:
280                     ls_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
281                 else:
282                     ls_path = cur_path
283                 print('ls解析后的路径:', ls_path)
284                 if os.path.isdir(ls_path):
285                     if home in ls_path:
286                         returnCode, filenames = commons.getFile(ls_path, home)
287                         resultCode = 300
288                         returnFilenames = filenames
289                     else:
290                         resultCode = 302
291                 else:
292                     resultCode = 302
293             else:
294                 resultCode = 301
295         else:
296             resultCode = 201
297         # 响应客户端的ls命令结果
298         time.sleep(5)
299         self.response(resultCode=resultCode, returnFilenames=returnFilenames)
300 
301     def rm(self, header_dic, req_dic):
302         cmds = header_dic['cmd'].split()
303         # 鉴权
304         status_code, user_info = self.auth(req_dic)
305         home = os.path.join(settings.user_home_dir, user_info['username'])
306         # 先定义响应信息
307         if status_code == 200:
308             if len(cmds) == 2:
309                 # 解析rm的真实路径
310                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
311                 rm_path = os.path.join('', *self.parse_file_path(os.path.dirname(cmds[1]), cur_path))
312                 rm_file = os.path.join(rm_path, os.path.basename(cmds[1]))
313                 print('rm解析后的文件或文件夹:', rm_file)
314                 if os.path.exists(rm_file):
315                     if home in rm_file:
316                         commons.rmdirs(rm_file)
317                         resultCode = 300
318                     else:
319                         resultCode = 302
320                 else:
321                     resultCode = 302
322             else:
323                 resultCode = 301
324         else:
325             resultCode = 201
326         # 响应客户端的rm命令结果
327         self.response(resultCode=resultCode)
328 
329     def mkdir(self, header_dic, req_dic):
330         cmds = header_dic['cmd'].split()
331         # 鉴权
332         status_code, user_info = self.auth(req_dic)
333         home = os.path.join(settings.user_home_dir, user_info['username'])
334         # 先定义响应信息
335         if status_code == 200:
336             if len(cmds) == 2:
337                 # 解析rm的真实路径
338                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
339                 mkdir_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
340                 print('mkdir解析后的文件夹:', mkdir_path)
341                 if not os.path.isdir(mkdir_path):
342                     if home in mkdir_path:
343                         os.makedirs(mkdir_path)
344                         resultCode = 300
345                     else:
346                         resultCode = 302
347                 else:
348                     resultCode = 303
349             else:
350                 resultCode = 301
351         else:
352             resultCode = 201
353         # 响应客户端的mkdir命令结果
354         self.response(resultCode=resultCode)
355 
356     def get(self, header_dic, req_dic):
357         """客户端下载文件"""
358         cmds = header_dic['cmd'].split()  # ['get', 'a.txt', 'download']
359         get_file = None
360         # 鉴权
361         status_code, user_info = self.auth(req_dic)
362         home = os.path.join(settings.user_home_dir, user_info['username'])
363         # 解析断点续传信息
364         position = 0
365         if req_dic['resume'] and isinstance(req_dic['position'], int):
366             position = req_dic['position']
367         # 先定义响应信息
368         resultCode = 300
369         FileSize = None
370         FileMd5 = None
371         if status_code == 200:
372             if 1 < len(cmds) < 4:
373                 # 解析需要get文件的真实路径
374                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
375                 get_file = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
376                 print('get解析后的路径:', get_file)
377                 if os.path.isfile(get_file):
378                     if home in get_file:
379                         FileSize = commons.getFileSize(get_file)
380                         if position >= FileSize != 0:
381                             resultCode = 304
382                         else:
383                             resultCode = 300
384                             FileSize = FileSize
385                             FileMd5 = commons.getFileMd5(get_file)
386                     else:
387                         resultCode = 302
388                 else:
389                     resultCode = 302
390             else:
391                 resultCode = 301
392         else:
393             resultCode = 201
394         # 响应客户端的get命令结果
395         self.response(resultCode=resultCode, FileSize=FileSize, FileMd5=FileMd5)
396         if resultCode == 300:
397             # 发送文件数据
398             with open(get_file, 'rb') as f:
399                 f.seek(position)
400                 for line in f:
401                     self.request.send(line)
402 
403     def put(self, header_dic, req_dic):
404         cmds = header_dic['cmd'].split()  # ['put', 'download/a.txt', 'video']
405         put_file = None
406         # 鉴权
407         status_code, user_info = self.auth(req_dic)
408         home = os.path.join(settings.user_home_dir, user_info['username'])
409         # 查询配额
410         used_quota = commons.getFileSize(user_info['home'])
411         total_quota = user_info['quota']
412         # 先定义响应信息
413         if status_code == 200:
414             if 1 < len(cmds) < 4:
415                 # 解析需要put文件的真实路径
416                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
417                 if len(cmds) == 3:
418                     put_file = os.path.join(os.path.join('', *self.parse_file_path(cmds[2], cur_path)),
419                                             os.path.basename(cmds[1]))
420                 else:
421                     put_file = os.path.join(cur_path, os.path.basename(cmds[1]))
422                 print('put解析后的文件:', put_file)
423                 put_path = os.path.dirname(put_file)
424                 if os.path.isdir(put_path):
425                     if home in put_path:
426                         if (req_dic['FileSize'] + used_quota) <= total_quota:
427                             resultCode = 300
428                         else:
429                             resultCode = 305
430                     else:
431                         resultCode = 302
432                 else:
433                     resultCode = 302
434             else:
435                 resultCode = 301
436         else:
437             resultCode = 201
438         # 响应客户端的put命令结果
439         self.response(resultCode=resultCode)
440         if resultCode == 300:
441             # 接收文件数据,写入文件
442             recv_size = 0
443             with open(put_file, 'wb') as f:
444                 while recv_size < req_dic['FileSize']:
445                     file_data = self.request.recv(1024)
446                     f.write(file_data)
447                     recv_size += len(file_data)
448             # 校检文件md5一致性
449             if commons.getFileMd5(put_file) == req_dic['FileMd5']:
450                 resultCode = 400
451                 print('\033[42;1m文件md5校检结果一致\033[0m')
452                 print('\033[42;1m文件上传成功,大小:%d,文件名:%s\033[0m' % (req_dic['FileSize'], put_file))
453             else:
454                 os.remove(put_file)
455                 resultCode = 401
456                 print('\033[31;1m文件md5校检结果不一致\033[0m')
457                 print('\033[42;1m文件上传失败\033[0m')
458             # 返回上传文件是否成功响应
459             self.response(resultCode=resultCode)
server.py

 

三、socketserver模块(内部采用seletors模块)实现并发效果(Linux支持epoll模型)

1、导入 socketserver模块

import socketserver

 

2、TCPHandler类继承socketsever.BaseRequestHandler

class TCPHandler(socketserver.BaseRequestHandler):

 

3、必须重写handle函数

4、完整代码:

server.py

  1 # -*- coding: utf-8 -*-
  2 # from gevent import monkey;monkey.patch_all()
  3 import socketserver
  4 import socket
  5 import os, json, re, struct, threading, time
  6 import gevent
  7 import select
  8 from lib import commons
  9 from conf import settings
 10 from core import logger
 11 
 12 class Server(object):
 13     def __init__(self):
 14         self.init_dir()
 15         # self.sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
 16         # # self.sock.setblocking(0)  # select实现同时处理请求,需要设置为非阻塞
 17         # # self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 18         # self.sock.bind((settings.server_bind_ip, settings.server_bind_port))
 19         # self.sock.listen(settings.server_listen)
 20         print("\033[42;1mserver started sucessful!\033[0m")
 21         self.run()
 22 
 23     @staticmethod
 24     def init_dir():
 25         if not os.path.exists(os.path.join(settings.base_path, 'logs')): os.mkdir(
 26             os.path.join(settings.base_path, 'logs'))
 27         if not os.path.exists(os.path.join(settings.base_path, 'db')): os.mkdir(os.path.join(settings.base_path, 'db'))
 28         if not os.path.exists(os.path.join(settings.base_path, 'home')): os.mkdir(
 29             os.path.join(settings.base_path, 'home'))
 30 
 31     def run(self):
 32         server = socketserver.ThreadingTCPServer((settings.server_bind_ip, settings.server_bind_port), TCPHandler)
 33         server.allow_reuse_address = True
 34         server.serve_forever()
 35         # while True:  # 链接循环
 36             # # select 单进程实现同时处理请求
 37             # inputs = [self.sock, ]
 38             # outputs = []
 39             # while True:
 40             #     readable, writeable, exceptional = select.select(inputs, outputs, inputs)
 41             #     for r in readable:
 42             #         if r is self.sock:
 43             #             request, client_address = self.sock.accept()
 44             #             inputs.append(request)
 45             #         else:
 46             #             print('处理request:%s'%id(r))
 47             #             return_code, request = TCPHandler().handle(r, )
 48             #             if not return_code:
 49             #                 request.close()
 50             #                 print('client[%s] is disconect' % ((request.getpeername()),))
 51             # self.request, self.client_address = self.sock.accept()
 52             # self.request.settimeout(300)
 53             # # 多线程处理请求
 54             # thread = threading.Thread(target=TCPHandler.handle, args=(TCPHandler(), self.request, self.client_address))
 55             # thread.start()
 56             # # gevent 实现单线程多并发
 57             # # gevent.spawn(TCPHandler.handle, TCPHandler(), self.request, self.client_address)
 58 
 59 
 60 class TCPHandler(socketserver.BaseRequestHandler):
 61     STATUS_CODE = {
 62         200: 'Passed authentication!',
 63         201: 'Wrong username or password!',
 64         202: 'Username does not exist!',
 65         300: 'cmd successful , the target path be returned in returnPath',
 66         301: 'cmd format error!',
 67         302: 'The path or file could not be found!',
 68         303: 'The dir is exist',
 69         304: 'The file has been downloaded or the size of the file is exceptions',
 70         305: 'Free space is not enough',
 71         401: 'File MD5 inspection failed',
 72         400: 'File MD5 inspection success',
 73     }
 74 
 75     def __init__(self, request, client_address, server):
 76         self.server_logger = logger.logger('server')
 77         self.server_logger.debug("server TCPHandler started successful!")
 78         super().__init__(request, client_address, server)
 79 
 80     def handle(self):
 81         # self.request = request
 82         # self.client_address = self.client_address
 83         self.server_logger.info('client[%s] is conecting' % ((self.client_address,)))
 84         print('client[%s] is conecting' % ((self.client_address,)))
 85         while True:  # 通讯循环
 86             try:
 87                 # 1、接收客户端的ftp命令
 88                 print("waiting receive client[%s] ftp command.." % ((self.client_address,)), id(self), self)
 89                 header_dic, req_dic = self.recv_request()
 90                 if not header_dic:break
 91                 if not header_dic['cmd']: break
 92                 print('receive client ftp command:%s' % header_dic['cmd'])
 93                 # 2、解析ftp命令,获取相应命令参数(文件名)
 94                 cmds = header_dic['cmd'].split()  # ['register',]、['get', 'a.txt']
 95                 if hasattr(self, cmds[0]):
 96                     self.server_logger.info('interface:[%s], request:{client:[%s:%s] action:[%s]}' % (
 97                         cmds[0], self.client_address[0], self.client_address[1], header_dic['cmd']))
 98                     getattr(self, cmds[0])(header_dic, req_dic)
 99             except (ConnectionResetError, ConnectionAbortedError):break
100             except socket.timeout:
101                 print('time out %s' % ((self.client_address,)))
102                 break
103         self.request.close()
104         self.server_logger.info('client %s is disconect' % ((self.client_address,)))
105         print('client[%s:%s] is disconect' % (self.client_address[0], self.client_address[1]))
106 
107     def unpack_header(self):
108         try:
109             pack_obj = self.request.recv(4)
110             header_size = struct.unpack('i', pack_obj)[0]
111             header_bytes = self.request.recv(header_size)
112             header_json = header_bytes.decode('utf-8')
113             header_dic = json.loads(header_json)
114             return header_dic
115         except struct.error:  # 避免客户端发送错误格式的header_size
116             return
117 
118     def unpack_info(self, info_size):
119         recv_size = 0
120         info_bytes = b''
121         while recv_size < info_size:
122             res = self.request.recv(1024)
123             info_bytes += res
124             recv_size += len(res)
125         info_json = info_bytes.decode('utf-8')
126         info_dic = json.loads(info_json)  # {'username':ton, 'password':123}
127         info_md5 = commons.getStrsMd5(info_bytes)
128         return info_dic, info_md5
129 
130     def recv_request(self):
131         header_dic = self.unpack_header()  # {'cmd':'register','info_size':0}
132         if not header_dic: return None, None
133         req_dic, info_md5 = self.unpack_info(header_dic['info_size'])
134         if header_dic.get('md5'):
135             # 校检请求内容md5一致性
136             if info_md5 == header_dic['md5']:
137                 pass
138             # print('\033[42;1m请求内容md5校检结果一致\033[0m')
139             else:
140                 pass
141                 # print('\033[31;1m请求内容md5校检结果不一致\033[0m')
142         return header_dic, req_dic
143 
144     def response(self, **kwargs):
145         rsp_info = kwargs
146         rsp_bytes = commons.getDictBytes(rsp_info)
147         md5 = commons.getStrsMd5(rsp_bytes)
148         header_size_pack, header_bytes = commons.make_header(info_size=len(rsp_bytes), md5=md5)
149         self.request.sendall(header_size_pack)
150         self.request.sendall(header_bytes)
151         self.request.sendall(rsp_bytes)
152 
153     def register(self, header_dic, req_dic):  # {'cmd':'register','info_size':0,'resultCode':0,'resultDesc':None}
154         username = req_dic['user_info']['username']
155         # 更新数据库,并制作响应信息字典
156         if not os.path.isfile(os.path.join(settings.db_file, '%s.json' % username)):
157             # 更新数据库
158             user_info = dict()
159             user_info['username'] = username
160             user_info['password'] = req_dic['user_info']['password']
161             user_info['home'] = os.path.join(settings.user_home_dir, username)
162             user_info['quota'] = settings.user_quota * (1024 * 1024)
163             commons.save_to_file(user_info, os.path.join(settings.db_file, '%s.json' % username))
164             resultCode = 0
165             resultDesc = None
166             # 创建家目录
167             if not os.path.exists(os.path.join(settings.user_home_dir, username)):
168                 os.mkdir(os.path.join(settings.user_home_dir, username))
169             self.server_logger.info('client[%s:%s] 注册用户[%s]成功' % (self.client_address[0], self.client_address[1], username))
170         else:
171             resultCode = 1
172             resultDesc = '该用户已存在,注册失败'
173             self.server_logger.warning('client[%s:%s] 注册用户[%s]失败:%s' % (self.client_address[0], self.client_address[1],
174                                                                         username, resultDesc))
175         # 响应客户端注册请求
176         self.response(resultCode=resultCode, resultDesc=resultDesc)
177 
178     @staticmethod
179     def auth(req_dic):
180         # print(req_dic['user_info'])
181         user_info = None
182         status_code = 201
183         try:
184             req_username = req_dic['user_info']['username']
185             db_file = os.path.join(settings.db_file, '%s.json' % req_username)
186             # 验证用户名密码,并制作响应信息字典
187             if not os.path.isfile(db_file):
188                 status_code = 202
189             else:
190                 with open(db_file, 'r') as f:
191                     user_info_db = json.load(f)
192                 if user_info_db['password'] == req_dic['user_info']['password']:
193                     status_code = 200
194                     user_info = user_info_db
195             return status_code, user_info
196         # 捕获  客户端鉴权请求时发送一个空字典或错误的字典  的异常
197         except KeyError:
198             return 201, user_info
199 
200     def login(self, header_dic, req_dic):
201         # 鉴权
202         status_code, user_info = self.auth(req_dic)
203         # 响应客户端登陆请求
204         self.response(user_info=user_info, resultCode=status_code)
205 
206     def query_quota(self, header_dic, req_dic):
207         used_quota = None
208         total_quota = None
209         # 鉴权
210         status_code, user_info = self.auth(req_dic)
211         # 查询配额
212         if status_code == 200:
213             used_quota = commons.getFileSize(user_info['home'])
214             total_quota = user_info['quota']
215         # 响应客户端配额查询请求
216         self.response(resultCode=status_code, total_quota=total_quota, used_quota=used_quota)
217 
218     @staticmethod
219     def parse_file_path(req_path, cur_path):
220         req_path = req_path.replace(r'/', '\\')
221         req_path = req_path.replace(r'//', r'/', )
222         req_path = req_path.replace('\\\\', '\\')
223         req_path = req_path.replace('~\\', '', 1)
224         req_path = req_path.replace(r'~', '', 1)
225         req_paths = re.findall(r'[^\\]+', req_path)
226         cur_paths = re.findall(r'[^\\]+', cur_path)
227         cur_paths.extend(req_paths)
228         cur_paths[0] += '\\'
229         while '.' in cur_paths:
230             cur_paths.remove('.')
231         while '..' in cur_paths:
232             for index, item in enumerate(cur_paths):
233                 if item == '..':
234                     cur_paths.pop(index)
235                     cur_paths.pop(index - 1)
236                     break
237         return cur_paths
238 
239     def cd(self, header_dic, req_dic):
240         cmds = header_dic['cmd'].split()
241         # 鉴权
242         status_code, user_info = self.auth(req_dic)
243         home = os.path.join(settings.user_home_dir, user_info['username'])
244         # 先定义响应信息
245         returnPath = req_dic['user_info']['cur_path']
246         if status_code == 200:
247             if len(cmds) != 1:
248                 # 解析cd的真实路径
249                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
250                 cd_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
251                 print('cd解析后的路径:', cd_path)
252                 if os.path.isdir(cd_path):
253                     if home in cd_path:
254                         resultCode = 300
255                         returnPath = cd_path.replace('%s\\' % settings.user_home_dir, '', 1)
256                     else:
257                         resultCode = 302
258                 else:
259                     resultCode = 302
260             else:
261                 resultCode = 301
262         else:
263             resultCode = 201
264         # 响应客户端的cd命令结果
265         print('cd发送给客户端的路径:', returnPath)
266         self.response(resultCode=resultCode, returnPath=returnPath)
267 
268     def ls(self, header_dic, req_dic):
269         cmds = header_dic['cmd'].split()
270         # 鉴权
271         status_code, user_info = self.auth(req_dic)
272         home = os.path.join(settings.user_home_dir, user_info['username'])
273         # 先定义响应信息
274         returnFilenames = None
275         if status_code == 200:
276             if len(cmds) <= 2:
277                 # 解析ls的真实路径
278                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
279                 if len(cmds) == 2:
280                     ls_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
281                 else:
282                     ls_path = cur_path
283                 print('ls解析后的路径:', ls_path)
284                 if os.path.isdir(ls_path):
285                     if home in ls_path:
286                         returnCode, filenames = commons.getFile(ls_path, home)
287                         resultCode = 300
288                         returnFilenames = filenames
289                     else:
290                         resultCode = 302
291                 else:
292                     resultCode = 302
293             else:
294                 resultCode = 301
295         else:
296             resultCode = 201
297         # 响应客户端的ls命令结果
298         self.response(resultCode=resultCode, returnFilenames=returnFilenames)
299 
300     def rm(self, header_dic, req_dic):
301         cmds = header_dic['cmd'].split()
302         # 鉴权
303         status_code, user_info = self.auth(req_dic)
304         home = os.path.join(settings.user_home_dir, user_info['username'])
305         # 先定义响应信息
306         if status_code == 200:
307             if len(cmds) == 2:
308                 # 解析rm的真实路径
309                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
310                 rm_path = os.path.join('', *self.parse_file_path(os.path.dirname(cmds[1]), cur_path))
311                 rm_file = os.path.join(rm_path, os.path.basename(cmds[1]))
312                 print('rm解析后的文件或文件夹:', rm_file)
313                 if os.path.exists(rm_file):
314                     if home in rm_file:
315                         commons.rmdirs(rm_file)
316                         resultCode = 300
317                     else:
318                         resultCode = 302
319                 else:
320                     resultCode = 302
321             else:
322                 resultCode = 301
323         else:
324             resultCode = 201
325         # 响应客户端的rm命令结果
326         self.response(resultCode=resultCode)
327 
328     def mkdir(self, header_dic, req_dic):
329         cmds = header_dic['cmd'].split()
330         # 鉴权
331         status_code, user_info = self.auth(req_dic)
332         home = os.path.join(settings.user_home_dir, user_info['username'])
333         # 先定义响应信息
334         if status_code == 200:
335             if len(cmds) == 2:
336                 # 解析rm的真实路径
337                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
338                 mkdir_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
339                 print('mkdir解析后的文件夹:', mkdir_path)
340                 if not os.path.isdir(mkdir_path):
341                     if home in mkdir_path:
342                         os.makedirs(mkdir_path)
343                         resultCode = 300
344                     else:
345                         resultCode = 302
346                 else:
347                     resultCode = 303
348             else:
349                 resultCode = 301
350         else:
351             resultCode = 201
352         # 响应客户端的mkdir命令结果
353         self.response(resultCode=resultCode)
354 
355     def get(self, header_dic, req_dic):
356         """客户端下载文件"""
357         cmds = header_dic['cmd'].split()  # ['get', 'a.txt', 'download']
358         get_file = None
359         # 鉴权
360         status_code, user_info = self.auth(req_dic)
361         home = os.path.join(settings.user_home_dir, user_info['username'])
362         # 解析断点续传信息
363         position = 0
364         if req_dic['resume'] and isinstance(req_dic['position'], int):
365             position = req_dic['position']
366         # 先定义响应信息
367         resultCode = 300
368         FileSize = None
369         FileMd5 = None
370         if status_code == 200:
371             if 1 < len(cmds) < 4:
372                 # 解析需要get文件的真实路径
373                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
374                 get_file = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
375                 print('get解析后的路径:', get_file)
376                 if os.path.isfile(get_file):
377                     if home in get_file:
378                         FileSize = commons.getFileSize(get_file)
379                         if position >= FileSize != 0:
380                             resultCode = 304
381                         else:
382                             resultCode = 300
383                             FileSize = FileSize
384                             FileMd5 = commons.getFileMd5(get_file)
385                     else:
386                         resultCode = 302
387                 else:
388                     resultCode = 302
389             else:
390                 resultCode = 301
391         else:
392             resultCode = 201
393         # 响应客户端的get命令结果
394         self.response(resultCode=resultCode, FileSize=FileSize, FileMd5=FileMd5)
395         if resultCode == 300:
396             # 发送文件数据
397             with open(get_file, 'rb') as f:
398                 f.seek(position)
399                 for line in f:
400                     self.request.send(line)
401 
402     def put(self, header_dic, req_dic):
403         cmds = header_dic['cmd'].split()  # ['put', 'download/a.txt', 'video']
404         put_file = None
405         # 鉴权
406         status_code, user_info = self.auth(req_dic)
407         home = os.path.join(settings.user_home_dir, user_info['username'])
408         # 查询配额
409         used_quota = commons.getFileSize(user_info['home'])
410         total_quota = user_info['quota']
411         # 先定义响应信息
412         if status_code == 200:
413             if 1 < len(cmds) < 4:
414                 # 解析需要put文件的真实路径
415                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
416                 if len(cmds) == 3:
417                     put_file = os.path.join(os.path.join('', *self.parse_file_path(cmds[2], cur_path)),
418                                             os.path.basename(cmds[1]))
419                 else:
420                     put_file = os.path.join(cur_path, os.path.basename(cmds[1]))
421                 print('put解析后的文件:', put_file)
422                 put_path = os.path.dirname(put_file)
423                 if os.path.isdir(put_path):
424                     if home in put_path:
425                         if (req_dic['FileSize'] + used_quota) <= total_quota:
426                             resultCode = 300
427                         else:
428                             resultCode = 305
429                     else:
430                         resultCode = 302
431                 else:
432                     resultCode = 302
433             else:
434                 resultCode = 301
435         else:
436             resultCode = 201
437         # 响应客户端的put命令结果
438         self.response(resultCode=resultCode)
439         if resultCode == 300:
440             # 接收文件数据,写入文件
441             recv_size = 0
442             with open(put_file, 'wb') as f:
443                 while recv_size < req_dic['FileSize']:
444                     file_data = self.request.recv(1024)
445                     f.write(file_data)
446                     recv_size += len(file_data)
447             # 校检文件md5一致性
448             if commons.getFileMd5(put_file) == req_dic['FileMd5']:
449                 resultCode = 400
450                 print('\033[42;1m文件md5校检结果一致\033[0m')
451                 print('\033[42;1m文件上传成功,大小:%d,文件名:%s\033[0m' % (req_dic['FileSize'], put_file))
452             else:
453                 os.remove(put_file)
454                 resultCode = 401
455                 print('\033[31;1m文件md5校检结果不一致\033[0m')
456                 print('\033[42;1m文件上传失败\033[0m')
457             # 返回上传文件是否成功响应
458             self.response(resultCode=resultCode)
server.py

PS:记得把from gevent import monkey;monkey.patch_all() 注释掉

 

四、seletors模块实现单线程并发效果

1、导入selectors模块

import selectors

 

2、把accept接受新客户端请求和recv数据分别写成两个函数

def accept(self, server):
        """接受新的client请求"""
        request, client_address = server.accept()
        request.setblocking(False)
        print('client%s is conecting' % ((client_address,)))
        self.sel.register(request, selectors.EVENT_READ, self.handle)
        self.dic[request] = {}

def handle(self, request):
        self.request = request
        self.client_address = request.getpeername()
        try:
            if not self.dic[request]:
                # 1、接收客户端的ftp命令
                header_dic, req_dic = self.recv_request()
                if not header_dic or not header_dic['cmd']:
                    print('client[%s] is disconect' % ((request.getpeername()),))
                    request.close()
                    self.sel.unregister(request)
                print( 'receive client ftp command:%s' % header_dic['cmd'])
            else:
                header_dic = self.dic[request]['header_dic']
                req_dic = self.dic[request]['req_dic']
            # 2、解析ftp命令,获取相应命令参数(文件名)
            cmds = header_dic['cmd'].split()  # ['register',]、['get', 'a.txt']
            if not self.dic[request]:
                if hasattr(self, cmds[0]):
                    getattr(self, cmds[0])(header_dic, req_dic)
            else:
                getattr(self, cmds[0])(self.dic[request]['header_dic'], self.dic[request]['req_dic'])
        except BlockingIOError as e:
            pass
        except (ConnectionResetError, ConnectionAbortedError, socket.timeout) as e:
            print('error: ',e)
            print('client[%s] is disconect' % ((request.getpeername()),))
            request.close()
            self.sel.unregister(request)

3、检测所有的fileobj,是否有完成wait data的

def run(self):
        while True:
            events = self.sel.select() # 检测所有注册的socket, 是否有完成wait data的
            for sel_obj, mask in events:
                callback = sel_obj.data # callback = accept
                callback(sel_obj.fileobj,)

 

4、完整代码:

  1 # -*- coding: utf-8 -*-
  2 # from gevent import monkey;monkey.patch_all()
  3 # import socketserver
  4 import selectors
  5 import socket
  6 import os, json, re, struct, threading, time
  7 # import gevent
  8 # import select
  9 from lib import commons
 10 from conf import settings
 11 from core import logger
 12 
 13 class Server(object):
 14     def __init__(self):
 15         self.init_dir()
 16         self.server_logger = logger.logger('server')
 17         self.sel = selectors.DefaultSelector()
 18         self.dic = {} # 记录文件传输未完成的状态
 19         self.create_socket()
 20         self.run()
 21 
 22     @staticmethod
 23     def init_dir():
 24         if not os.path.exists(os.path.join(settings.base_path, 'logs')): os.mkdir(
 25             os.path.join(settings.base_path, 'logs'))
 26         if not os.path.exists(os.path.join(settings.base_path, 'db')): os.mkdir(os.path.join(settings.base_path, 'db'))
 27         if not os.path.exists(os.path.join(settings.base_path, 'home')): os.mkdir(
 28             os.path.join(settings.base_path, 'home'))
 29 
 30     def create_socket(self):
 31         self.server = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
 32         self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 33         self.server.bind((settings.server_bind_ip, settings.server_bind_port))
 34         self.server.listen(settings.server_listen)
 35         self.server.setblocking(False)  # 设置为非阻塞
 36         self.sel.register(self.server, selectors.EVENT_READ, self.accept)
 37         print("\033[42;1mserver started sucessful!\033[0m")
 38 
 39     def accept(self, server):
 40         """接受新的client请求"""
 41         request, client_address = server.accept()
 42         request.setblocking(False)
 43         print('client%s is conecting' % ((client_address,)))
 44         self.sel.register(request, selectors.EVENT_READ, self.handle)
 45         self.dic[request] = {}
 46 
 47     def run(self):
 48         while True:
 49             events = self.sel.select() # 检测所有注册的socket, 是否有完成wait data的
 50             for sel_obj, mask in events:
 51                 callback = sel_obj.data # callback = accept
 52                 callback(sel_obj.fileobj,)
 53         # socketserver模块实现多并发
 54         # server = socketserver.ThreadingTCPServer((settings.server_bind_ip, settings.server_bind_port), TCPHandler)
 55         # server.allow_reuse_address = True
 56         # server.serve_forever()
 57         # while True:  # 链接循环
 58             # # select 单进程实现同时处理请求
 59             # inputs = [self.server, ]
 60             # outputs = []
 61             # while True:
 62             #     readable, writeable, exceptional = select.select(inputs, outputs, inputs)
 63             #     for r in readable:
 64             #         if r is self.server:
 65             #             request, client_address = self.server.accept()
 66             #             inputs.append(request)
 67             #         else:
 68             #             print('处理request:%s'%id(r))
 69             #             return_code, request = TCPHandler().handle(r, )
 70             #             if not return_code:
 71             #                 request.close()
 72             #                 print('client[%s] is disconect' % ((request.getpeername()),))
 73             # self.request, self.client_address = self.server.accept()
 74             # self.request.settimeout(300)
 75             # # 多线程处理请求
 76             # thread = threading.Thread(target=TCPHandler.handle, args=(TCPHandler(), self.request, self.client_address))
 77             # thread.start()
 78             # # gevent 实现单线程多并发
 79             # # gevent.spawn(TCPHandler.handle, TCPHandler(), self.request, self.client_address)
 80 # class TCPHandler():
 81     STATUS_CODE = {
 82         200: 'Passed authentication!',
 83         201: 'Wrong username or password!',
 84         202: 'Username does not exist!',
 85         300: 'cmd successful , the target path be returned in returnPath',
 86         301: 'cmd format error!',
 87         302: 'The path or file could not be found!',
 88         303: 'The dir is exist',
 89         304: 'The file has been downloaded or the size of the file is exceptions',
 90         305: 'Free space is not enough',
 91         401: 'File MD5 inspection failed',
 92         400: 'File MD5 inspection success',
 93     }
 94 
 95     # def __init__(self, request, client_address):
 96         # self.server_logger = logger.logger('server')
 97         # self.server_logger.debug("server TCPHandler started successful!")
 98         # self.request = request
 99         # self.client_address = client_address
100         # super().__init__(request, client_address, server)
101 
102     def handle(self, request):
103         self.request = request
104         self.client_address = request.getpeername()
105         try:
106             if not self.dic[request]:
107                 # 1、接收客户端的ftp命令
108                 header_dic, req_dic = self.recv_request()
109                 if not header_dic or not header_dic['cmd']:
110                     print('client[%s] is disconect' % ((request.getpeername()),))
111                     request.close()
112                     self.sel.unregister(request)
113                 print( 'receive client ftp command:%s' % header_dic['cmd'])
114             else:
115                 header_dic = self.dic[request]['header_dic']
116                 req_dic = self.dic[request]['req_dic']
117             # 2、解析ftp命令,获取相应命令参数(文件名)
118             cmds = header_dic['cmd'].split()  # ['register',]、['get', 'a.txt']
119             if not self.dic[request]:
120                 if hasattr(self, cmds[0]):
121                     getattr(self, cmds[0])(header_dic, req_dic)
122             else:
123                 getattr(self, cmds[0])(self.dic[request]['header_dic'], self.dic[request]['req_dic'])
124         except BlockingIOError as e:
125             pass
126         except (ConnectionResetError, ConnectionAbortedError, socket.timeout) as e:
127             print('error: ',e)
128             print('client[%s] is disconect' % ((request.getpeername()),))
129             request.close()
130             self.sel.unregister(request)
131         # self.request.close()
132         # self.server_logger.info('client %s is disconect' % ((self.client_address,)))
133         # print('client[%s:%s] is disconect' % (self.client_address[0], self.client_address[1]))
134 
135     def unpack_header(self):
136         try:
137             pack_obj = self.request.recv(4)
138             header_size = struct.unpack('i', pack_obj)[0]
139             time.sleep(1/(10**20))
140             header_bytes = self.request.recv(header_size)
141             header_json = header_bytes.decode('utf-8')
142             header_dic = json.loads(header_json)
143             return header_dic
144         except struct.error:  # 避免客户端发送错误格式的header_size
145             return
146 
147     def unpack_info(self, info_size):
148         recv_size = 0
149         info_bytes = b''
150         while recv_size < info_size:
151             res = self.request.recv(1024)
152             info_bytes += res
153             recv_size += len(res)
154         info_json = info_bytes.decode('utf-8')
155         info_dic = json.loads(info_json)  # {'username':ton, 'password':123}
156         info_md5 = commons.getStrsMd5(info_bytes)
157         return info_dic, info_md5
158 
159     def recv_request(self):
160         header_dic = self.unpack_header()  # {'cmd':'register','info_size':0}
161         if not header_dic: return None, None
162         req_dic, info_md5 = self.unpack_info(header_dic['info_size'])
163         if header_dic.get('md5'):
164             # 校检请求内容md5一致性
165             if info_md5 == header_dic['md5']:
166                 pass
167             # print('\033[42;1m请求内容md5校检结果一致\033[0m')
168             else:
169                 pass
170                 # print('\033[31;1m请求内容md5校检结果不一致\033[0m')
171         return header_dic, req_dic
172 
173     def response(self, **kwargs):
174         rsp_info = kwargs
175         rsp_bytes = commons.getDictBytes(rsp_info)
176         md5 = commons.getStrsMd5(rsp_bytes)
177         header_size_pack, header_bytes = commons.make_header(info_size=len(rsp_bytes), md5=md5)
178         self.request.sendall(header_size_pack)
179         self.request.sendall(header_bytes)
180         self.request.sendall(rsp_bytes)
181 
182     def register(self, header_dic, req_dic):  # {'cmd':'register','info_size':0,'resultCode':0,'resultDesc':None}
183         username = req_dic['user_info']['username']
184         # 更新数据库,并制作响应信息字典
185         if not os.path.isfile(os.path.join(settings.db_file, '%s.json' % username)):
186             # 更新数据库
187             user_info = dict()
188             user_info['username'] = username
189             user_info['password'] = req_dic['user_info']['password']
190             user_info['home'] = os.path.join(settings.user_home_dir, username)
191             user_info['quota'] = settings.user_quota * (1024 * 1024)
192             commons.save_to_file(user_info, os.path.join(settings.db_file, '%s.json' % username))
193             resultCode = 0
194             resultDesc = None
195             # 创建家目录
196             if not os.path.exists(os.path.join(settings.user_home_dir, username)):
197                 os.mkdir(os.path.join(settings.user_home_dir, username))
198             self.server_logger.info('client[%s:%s] 注册用户[%s]成功' % (self.client_address[0], self.client_address[1], username))
199         else:
200             resultCode = 1
201             resultDesc = '该用户已存在,注册失败'
202             self.server_logger.warning('client[%s:%s] 注册用户[%s]失败:%s' % (self.client_address[0], self.client_address[1],
203                                                                         username, resultDesc))
204         # 响应客户端注册请求
205         self.response(resultCode=resultCode, resultDesc=resultDesc)
206 
207     @staticmethod
208     def auth(req_dic):
209         # print(req_dic['user_info'])
210         user_info = None
211         status_code = 201
212         try:
213             req_username = req_dic['user_info']['username']
214             db_file = os.path.join(settings.db_file, '%s.json' % req_username)
215             # 验证用户名密码,并制作响应信息字典
216             if not os.path.isfile(db_file):
217                 status_code = 202
218             else:
219                 with open(db_file, 'r') as f:
220                     user_info_db = json.load(f)
221                 if user_info_db['password'] == req_dic['user_info']['password']:
222                     status_code = 200
223                     user_info = user_info_db
224             return status_code, user_info
225         # 捕获  客户端鉴权请求时发送一个空字典或错误的字典  的异常
226         except KeyError:
227             return 201, user_info
228 
229     def login(self, header_dic, req_dic):
230         # 鉴权
231         status_code, user_info = self.auth(req_dic)
232         # 响应客户端登陆请求
233         self.response(user_info=user_info, resultCode=status_code)
234 
235     def query_quota(self, header_dic, req_dic):
236         used_quota = None
237         total_quota = None
238         # 鉴权
239         status_code, user_info = self.auth(req_dic)
240         # 查询配额
241         if status_code == 200:
242             used_quota = commons.getFileSize(user_info['home'])
243             total_quota = user_info['quota']
244         # 响应客户端配额查询请求
245         self.response(resultCode=status_code, total_quota=total_quota, used_quota=used_quota)
246 
247     @staticmethod
248     def parse_file_path(req_path, cur_path):
249         req_path = req_path.replace(r'/', '\\')
250         req_path = req_path.replace(r'//', r'/', )
251         req_path = req_path.replace('\\\\', '\\')
252         req_path = req_path.replace('~\\', '', 1)
253         req_path = req_path.replace(r'~', '', 1)
254         req_paths = re.findall(r'[^\\]+', req_path)
255         cur_paths = re.findall(r'[^\\]+', cur_path)
256         cur_paths.extend(req_paths)
257         cur_paths[0] += '\\'
258         while '.' in cur_paths:
259             cur_paths.remove('.')
260         while '..' in cur_paths:
261             for index, item in enumerate(cur_paths):
262                 if item == '..':
263                     cur_paths.pop(index)
264                     cur_paths.pop(index - 1)
265                     break
266         return cur_paths
267 
268     def cd(self, header_dic, req_dic):
269         cmds = header_dic['cmd'].split()
270         # 鉴权
271         status_code, user_info = self.auth(req_dic)
272         home = os.path.join(settings.user_home_dir, user_info['username'])
273         # 先定义响应信息
274         returnPath = req_dic['user_info']['cur_path']
275         if status_code == 200:
276             if len(cmds) != 1:
277                 # 解析cd的真实路径
278                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
279                 cd_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
280                 print('cd解析后的路径:', cd_path)
281                 if os.path.isdir(cd_path):
282                     if home in cd_path:
283                         resultCode = 300
284                         returnPath = cd_path.replace('%s\\' % settings.user_home_dir, '', 1)
285                     else:
286                         resultCode = 302
287                 else:
288                     resultCode = 302
289             else:
290                 resultCode = 301
291         else:
292             resultCode = 201
293         # 响应客户端的cd命令结果
294         print('cd发送给客户端的路径:', returnPath)
295         self.response(resultCode=resultCode, returnPath=returnPath)
296 
297     def ls(self, header_dic, req_dic):
298         cmds = header_dic['cmd'].split()
299         # 鉴权
300         status_code, user_info = self.auth(req_dic)
301         home = os.path.join(settings.user_home_dir, user_info['username'])
302         # 先定义响应信息
303         returnFilenames = None
304         if status_code == 200:
305             if len(cmds) <= 2:
306                 # 解析ls的真实路径
307                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
308                 if len(cmds) == 2:
309                     ls_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
310                 else:
311                     ls_path = cur_path
312                 print('ls解析后的路径:', ls_path)
313                 if os.path.isdir(ls_path):
314                     if home in ls_path:
315                         returnCode, filenames = commons.getFile(ls_path, home)
316                         resultCode = 300
317                         returnFilenames = filenames
318                     else:
319                         resultCode = 302
320                 else:
321                     resultCode = 302
322             else:
323                 resultCode = 301
324         else:
325             resultCode = 201
326         # 响应客户端的ls命令结果
327         self.response(resultCode=resultCode, returnFilenames=returnFilenames)
328 
329     def rm(self, header_dic, req_dic):
330         cmds = header_dic['cmd'].split()
331         # 鉴权
332         status_code, user_info = self.auth(req_dic)
333         home = os.path.join(settings.user_home_dir, user_info['username'])
334         # 先定义响应信息
335         if status_code == 200:
336             if len(cmds) == 2:
337                 # 解析rm的真实路径
338                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
339                 rm_path = os.path.join('', *self.parse_file_path(os.path.dirname(cmds[1]), cur_path))
340                 rm_file = os.path.join(rm_path, os.path.basename(cmds[1]))
341                 print('rm解析后的文件或文件夹:', rm_file)
342                 if os.path.exists(rm_file):
343                     if home in rm_file:
344                         commons.rmdirs(rm_file)
345                         resultCode = 300
346                     else:
347                         resultCode = 302
348                 else:
349                     resultCode = 302
350             else:
351                 resultCode = 301
352         else:
353             resultCode = 201
354         # 响应客户端的rm命令结果
355         self.response(resultCode=resultCode)
356 
357     def mkdir(self, header_dic, req_dic):
358         cmds = header_dic['cmd'].split()
359         # 鉴权
360         status_code, user_info = self.auth(req_dic)
361         home = os.path.join(settings.user_home_dir, user_info['username'])
362         # 先定义响应信息
363         if status_code == 200:
364             if len(cmds) == 2:
365                 # 解析rm的真实路径
366                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
367                 mkdir_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
368                 print('mkdir解析后的文件夹:', mkdir_path)
369                 if not os.path.isdir(mkdir_path):
370                     if home in mkdir_path:
371                         os.makedirs(mkdir_path)
372                         resultCode = 300
373                     else:
374                         resultCode = 302
375                 else:
376                     resultCode = 303
377             else:
378                 resultCode = 301
379         else:
380             resultCode = 201
381         # 响应客户端的mkdir命令结果
382         self.response(resultCode=resultCode)
383 
384     def get(self, header_dic, req_dic):
385         """客户端下载文件"""
386         cmds = header_dic['cmd'].split()  # ['get', 'a.txt', 'download']
387         # 解析需要get文件的真实路径
388         cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
389         get_file = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
390         # 鉴权
391         status_code, user_info = self.auth(req_dic)
392         home = os.path.join(settings.user_home_dir, user_info['username'])
393         # 解析断点续传信息
394         position = 0
395         if req_dic['resume'] and isinstance(req_dic['position'], int):
396             position = req_dic['position']
397         # 先定义响应信息
398         resultCode = 300
399         FileSize = None
400         FileMd5 = None
401         if status_code == 200:
402             if 1 < len(cmds) < 4:
403                 print('get解析后的路径:', get_file)
404                 if os.path.isfile(get_file):
405                     if home in get_file:
406                         FileSize = commons.getFileSize(get_file)
407                         if position >= FileSize != 0:
408                             resultCode = 304
409                         else:
410                             resultCode = 300
411                             FileSize = FileSize
412                             FileMd5 = commons.getFileMd5(get_file)
413                     else:
414                         resultCode = 302
415                 else:
416                     resultCode = 302
417             else:
418                 resultCode = 301
419         else:
420             resultCode = 201
421         # 响应客户端的get命令结果
422         self.response(resultCode=resultCode, FileSize=FileSize, FileMd5=FileMd5)
423         if resultCode == 300:
424             self.request.setblocking(False)
425             # 发送文件数据
426             with open(get_file, 'rb') as f:
427                 f.seek(position)
428                 for line in f:
429                     self.request.send(line)
430                     position += len(line)
431 
432     def put(self, header_dic, req_dic):
433         cmds = header_dic['cmd'].split()  # ['put', 'download/a.txt', 'video']
434         # 解析需要put文件的真实路径
435         cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
436         if len(cmds) == 3:
437             put_file = os.path.join(os.path.join('', *self.parse_file_path(cmds[2], cur_path)),
438                                     os.path.basename(cmds[1]))
439         else:
440             put_file = os.path.join(cur_path, os.path.basename(cmds[1]))
441         if not self.dic[self.request]:
442             self.dic[self.request]['action'] = 'put'
443             self.dic[self.request]['header_dic'] = header_dic
444             self.dic[self.request]['req_dic'] = req_dic
445             self.dic[self.request]['position'] = 0
446             # 鉴权
447             status_code, user_info = self.auth(req_dic)
448             home = os.path.join(settings.user_home_dir, user_info['username'])
449             # 查询配额
450             used_quota = commons.getFileSize(user_info['home'])
451             total_quota = user_info['quota']
452             # 先定义响应信息
453             if status_code == 200:
454                 if 1 < len(cmds) < 4:
455                     print('put解析后的文件:', put_file)
456                     put_path = os.path.dirname(put_file)
457                     if os.path.isdir(put_path):
458                         if home in put_path:
459                             if (req_dic['FileSize'] + used_quota) <= total_quota:
460                                 resultCode = 300
461                             else:
462                                 resultCode = 305
463                         else:
464                             resultCode = 302
465                     else:
466                         resultCode = 302
467                 else:
468                     resultCode = 301
469             else:
470                 resultCode = 201
471             # 响应客户端的put命令结果
472             self.response(resultCode=resultCode)
473             if resultCode == 300:
474                 # 接收文件数据,写入文件
475                 recv_size = 0
476                 with open(put_file, 'wb') as f:
477                     while recv_size < req_dic['FileSize']:
478                         file_data = self.request.recv(1024)
479                         f.write(file_data)
480                         recv_size += len(file_data)
481                         self.dic[self.request]['position'] = recv_size
482         else:
483             # 接收文件数据,写入文件
484             recv_size = self.dic[self.request]['position']
485             with open(put_file, 'ab') as f:
486                 while recv_size < req_dic['FileSize']:
487                     file_data = self.request.recv(1024)
488                     f.write(file_data)
489                     recv_size += len(file_data)
490                     self.dic[self.request]['position'] = recv_size
491         # 校检文件md5一致性
492         if commons.getFileMd5(put_file) == req_dic['FileMd5']:
493             resultCode = 400
494             print('\033[42;1m文件md5校检结果一致\033[0m')
495             print('\033[42;1m文件上传成功,大小:%d,文件名:%s\033[0m' % (req_dic['FileSize'], put_file))
496         else:
497             os.remove(put_file)
498             resultCode = 401
499             print('\033[31;1m文件md5校检结果不一致\033[0m')
500             print('\033[42;1m文件上传失败\033[0m')
501         # 返回上传文件是否成功响应
502         self.response(resultCode=resultCode)
503         self.dic[self.request] = {}
server.py

 

 

posted on 2018-01-16 17:38  泽桐-Merlin  阅读(587)  评论(0编辑  收藏  举报