Python修炼之路-Socket

网络编程

socket套接字

socket通常也称作"套接字",用于描述IP地址和端口,是一个通信链的句柄,应用程序通常通过“套接字”向网络发出请求或者应答网络请求。

socket模块是针对 服务器端 和 客户端Socket 进行【打开】【读写】【关闭】

TCP协议Socket

tcp是基于链接的,必须先启动服务端,然后再启动客户端去连接服务端。

server端

import socket

s = socket.socket()
s.bind(('127.0.0.1',8888))     # 把地址绑定到套接字
s.listen()                     # 监听链接
conn,addr = s.accept()         # 接受客户端链接, conn是客户端连接过来而在服务器端为其生成的一个连接实例
print(addr) 
print(conn)
data = conn.recv(1024)         # 接收客户端信息
print(data)                    # 打印客户端信息
conn.send(b'hi')               # 向客户端发送信息
conn.close()  # 关闭客户端套接字
s.close()  # 关闭服务器套接字(可选)

client端

import socket

s = socket.socket()             # 创建客户套接字
s.connect(('127.0.0.1',8888))   # 尝试连接服务器
s.send(b'hello!')
res = s.recv(1024)             # 对话(发送/接收)
print(res)
s.close()                      # 关闭客户套接字

UDP协议Socket

udp是无链接的,启动服务之后可以直接接受消息,不需要提前建立链接。

server

import socket
udp_s = socket.socket(type=socket.SOCK_DGRAM)   # 创建一个服务器的套接字
udp_s.bind(('127.0.0.1',9000))                  # 绑定服务器套接字
msg,addr = udp_s.recvfrom(1024)
print(msg)
udp_s.sendto(b'hi',addr)                 # 对话(接收与发送)
udp_s.close()                            # 关闭服务器套接字

client

import socket

ip_port = ('127.0.0.1',9000)
udp_c=socket.socket(type=socket.SOCK_DGRAM)
udp_c.sendto(b'hello',ip_port)
back_msg,addr=udp_c.recvfrom(1024)
print(back_msg.decode('utf-8'),addr)

socket参数

创建一个socket

socket.socket( family=AF_INET, type=SOCK_STREAM,  proto=0,  fileno=None)

socket参数解析

参数 解析
family 地址簇:socket.AF_INET : IPv4(默认)
socket.AF_INET6: IPv6
socket.AF_UNIX: 只能够用于单一的Unix系统进程间通信,使用本地 socket 文件来通信
type socket.SOCK_STREAM: 默认类型,数据流,TCP协议,有保障的(即能保证数据正确传送到对方)面向连接的SOCKET,多用于资料传送。
socket.SOCK_DGRAM: 数据报式,UDP协议,无保障的面向消息的socket,多用于在网络上发广播信息。
proto 协议号,通常为零,可以省略,与特定的地址家族相关的协议,如果是 0,则系统就会根据地址格式和套接类别,自动选择一个合适的协议;在地址族为AF_CAN的情况下,协议应为CAN_RAW或CAN_BCM之一。
fileno 如果指定了fileno,则其他参数将被忽略,导致带有指定文件描述符的套接字返回。与socket.fromfd()不同,fileno将返回相同的套接字,而不是重复的。这可能有助于使用socket.close()关闭一个独立的插座。

socket常用方法

方法 说明
s.bind(address) 将套接字绑定到地址。address地址的格式取决于地址族。在AF_INET下,以元组(host,port)的形式表示地址
s.listen(backlog) 开始监听传入连接。backlog指定在拒绝连接之前,可以挂起的最大连接数量。
backlog等于5,表示内核已经接到了连接请求,但服务器还没有调用accept进行处理的连接个数最大为5。
这个值不能无限大,因为要在内核中维护连接队列。
s.setblocking(bool) 是否阻塞,默认True,如果设置False,那么accept和recv时一旦无数据,则报错。
s.accept() 接受连接并返回(conn,address),其中conn是新的套接字对象,可以用来接收和发送数据。address是连接客户端的地址。
接收TCP 客户的连接(阻塞式)等待连接的到来。
s.connect(address) 连接到address处的套接字。一般,address的格式为元组(hostname,port),如果连接出错,返回socket.error错误。
s.connect_ex(address) 同上,只不过会有返回值,连接成功时返回 0 ,连接失败时候返回编码,例如:10061。
s.close() 关闭套接字。
s.recv(bufsize[,flag]) 接受套接字的数据。数据以字符串形式返回,bufsize指定最多可以接收的数量。flag提供有关消息的其他信息,通常可以忽略。
s.recvfrom(bufsize[.flag]) 与recv()类似,但返回值是(data,address)。其中data是包含接收数据的字符串,address是发送数据的套接字地址。
s.send(string[,flag]) 将string中的数据发送到连接的套接字。返回值是要发送的字节数量,该数量可能小于string的字节大小。即:可能未将指定内容全部发送。
s.sendall(string[,flag]) 将string中的数据发送到连接的套接字,但在返回之前会尝试发送所有数据。成功返回None,失败则抛出异常。
内部通过递归调用send,将所有内容发送出去。
s.sendto(string[,flag],address) 将数据发送到套接字,address是形式为(ipaddr,port)的元组,指定远程地址。返回值是发送的字节数。该函数主要用于UDP协议。
s.settimeout(timeout) 设置套接字操作的超时期,timeout是一个浮点数,单位是秒。值为None表示没有超时期。一般,超时期应该在刚创建套接字时设置,因为它们可能用于连接的操作(如 client 连接最多等待5s )
s.getpeername() 返回连接套接字的远程地址。返回值通常是元组(ipaddr,port)。
s.getsockname() 返回套接字自己的地址。通常是一个元组(ipaddr,port)
s.fileno() 套接字的文件描述符

黏包现象

黏包:同时执行多条命令之后,得到的结果很可能只有一部分,在执行其他命令的时候又接收到之前执行的另外一部分结果,这种显现就是黏包。

TCP协议会出现黏包问题,UDP协议不会出现黏包。

server

from socket import *
import subprocess

ip_port=('127.0.0.1',8888)
BUFSIZE=1024

s = socket(AF_INET,SOCK_STREAM)
s.setsockopt(SOL_SOCKET,SO_REUSEADDR,1)
s.bind(ip_port)
s.listen(5)

while True:
    conn,addr = s.accept()
    print('客户端',addr)
    while True:
        cmd = conn.recv(BUFSIZE)
        if len(cmd) == 0:break
        res = subprocess.Popen(cmd.decode('utf-8'),shell=True,
                         stdout=subprocess.PIPE,
                         stdin=subprocess.PIPE,
                         stderr=subprocess.PIPE)
        stderr = res.stderr.read()
        stdout = res.stdout.read()
        conn.send(stderr)
        conn.send(stdout)

client

import socket
BUFSIZE
=1024 ip_port=('127.0.0.1',8888) c = socket.socket(socket.AF_INET,socket.SOCK_STREAM) res = c.connect_ex(ip_port) while True: msg=input('>>: ').strip() if len(msg) == 0:continue if msg == 'quit':break c.send(msg.encode('utf-8')) act_res=c.recv(BUFSIZE) print(act_res.decode('utf-8'),end='')

黏包产生

发送方的缓存机制

发送端需要等缓冲区满才发送出去,造成粘包(发送数据时间间隔很短,数据了很小,会合到一起,产生粘包)

# server
from socket import *
ip_port=('127.0.0.1',8080)

s = socket(AF_INET,SOCK_STREAM)
s.bind(ip_port)
s.listen(5)

conn,addr = s.accept()

data1=conn.recv(10)
data2=conn.recv(10)
print('----->',data1.decode('utf-8'))
print('----->',data2.decode('utf-8'))
conn.close()

# client
import socket
BUFSIZE=1024
ip_port=('127.0.0.1',8080)

s=socket.socket(socket.AF_INET,socket.SOCK_STREAM)
res=s.connect_ex(ip_port)

s.send('hello'.encode('utf-8'))
s.send('egg'.encode('utf-8'))

接收方的缓存机制

接收方不及时接收缓冲区的包,造成多个包接收(客户端发送了一段数据,服务端只收了一小部分,服务端下次再收的时候还是从缓冲区拿上次遗留的数据,产生粘包)

#  服务端
from socket import *
ip_port=('127.0.0.1',8080)

s = socket(AF_INET,SOCK_STREAM)
s.bind(ip_port)
s.listen(5)

conn,addr = s.accept()

data1=conn.recv(2)    # 一次没有收完整
data2=conn.recv(10)  # 下次收的时候,会先取旧的数据,然后取新的

print('----->',data1.decode('utf-8'))
print('----->',data2.decode('utf-8'))
conn.close()

#  客户端
import socket
BUFSIZE=1024
ip_port=('127.0.0.1',8080)
c = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
res = c.connect_ex(ip_port)
c.send('hello egg'.encode('utf-8'))

注:

黏包现象只发生在tcp协议中:

  • 从表面上看,黏包问题主要是因为发送方和接收方的缓存机制、tcp协议面向流通信的特点。
  • 实际上,主要还是因为接收方不知道消息之间的界限,不知道一次性提取多少字节的数据所造成的

黏包解决方案

在发送数据前,把将要发送的字节流总大小给接收端,接收端通过循环多次接受并判断是否接收完所有数据。

注意

     接收端判断总接收字节大小时(recv_size < length),会出现接收到的总字节数大于发送端发送过来的总字节数。这是因为,len(一个汉字)=1,len('国'.encode())=3;如果客户端计算字节数不是统计encode()之后的数目,而客户端统计encode后的数目,那么客户端统计的总字节数大于发送端发送过来的总字节数。这也是为什么判断recv_size < length退出循环,而不是判断两者相等退出循环。

#  server
import socket,subprocess
ip_port=('127.0.0.1',8080)
s = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(ip_port)
s.listen(5)
while True:
    conn,addr=s.accept()while True:
        msg = conn.recv(1024)
        if not msg:break
        res=subprocess.Popen(msg.decode('utf-8'),shell=True,stdin=subprocess.PIPE,stderr=subprocess.PIPE,stdout=subprocess.PIPE)
        err = res.stderr.read()
        if err:
            ret = err
        else:
            ret = res.stdout.read()
        data_length = len(ret)
        conn.send(str(data_length).encode('utf-8'))
        data = conn.recv(1024).decode('utf-8')
        if data == 'recv_ready':
            conn.sendall(ret)
    conn.close()

# client
import socket,time
c = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
res = c.connect_ex(('127.0.0.1',8080))
while True:
    msg=input('>>: ').strip()
    if len(msg) == 0:continue
    if msg == 'quit':break
    c.send(msg.encode('utf-8'))
    length = int(s.recv(1024).decode('utf-8'))
    s.send('recv_ready'.encode('utf-8'))
    recv_size=0
    data = b''
    while recv_size < length:
        data+=s.recv(1024)
        recv_size+=len(data)
    print(data.decode('utf-8'))

端口占用问题

Windows

# 加入一条socket配置,重用ip和端口
import socket
from socket import SOL_SOCKET,SO_REUSEADDR
sk = socket.socket()
sk.setsockopt(SOL_SOCKET,SO_REUSEADDR,1)  # 就是它,在bind前加
sk.bind(('127.0.0.1',8898))  # 把地址绑定到套接字
sk.listen()  # 监听链接
conn,addr = sk.accept() # 接受客户端链接
ret = conn.recv(1024)   # 接收客户端信息
print(ret)  # 打印客户端信息
conn.send(b'hi')  # 向客户端发送信息
conn.close()  # 关闭客户端套接字
sk.close()  # 关闭服务器套接字(可选)

Linux

Unix系统可以找到占用端口的进程,kill掉进程

 


socketserver

四种基本server类型

  • class socketserver.TCPServer(server_addressRequestHandlerClassbind_and_activate=True)       
  • class socketserver.UDPServer(server_addressRequestHandlerClassbind_and_activate=True)
  • class socketserver.UnixStreamServer(server_addressRequestHandlerClassbind_and_activate=True)
  • class socketserver.UnixDatagramServer(server_addressRequestHandlerClass,bind_and_activate=True)

ThreadingTCPServer使用:

  1. 创建一个继承socketserver.BaseRequestHandler的类:class MyTCPHandler(socketserver.BaseRequestHandler)
  2. 重写handler方法: def handle(self):
  3. 实例化ThreadingTCPServer: server = socketserver.ThreadingTCPServer((HOST, PORT), MyTCPHandler)
  4. 启动server: server.serve_forever()

server

import socketserver

class Myserver(socketserver.BaseRequestHandler):
    def handle(self):
        self.data = self.request.recv(1024).strip()
        print("{} wrote:".format(self.client_address[0]))
        print(self.data)
        self.request.sendall(self.data.upper())

if __name__ == "__main__":
    HOST, PORT = "127.0.0.1", 9999

    # 设置allow_reuse_address允许服务器重用地址
    socketserver.TCPServer.allow_reuse_address = True
    # 创建一个server, 将服务地址绑定到127.0.0.1:9999
    server = socketserver.TCPServer((HOST, PORT),Myserver)
    # 让server永远运行下去,除非强制停止程序
    server.serve_forever()

client

import socket

HOST, PORT = "127.0.0.1", 9999
data = "hello"

# 创建一个socket链接,SOCK_STREAM代表使用TCP协议
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
    sock.connect((HOST, PORT))          # 链接到客户端
    sock.sendall(bytes(data + "\n", "utf-8")) # 向服务端发送数据
    received = str(sock.recv(1024), "utf-8")# 从服务端接收数据

print("Sent:     {}".format(data))
print("Received: {}".format(received))

实例

import os
import json
import hashlib
import socketserver
from conf import setting
from core import sql


class MyTCPHandler(socketserver.BaseRequestHandler):

    def handle(self):
        while True:
            try:
                self.server_command = self.request.recv(1024).strip()
                print('receive command: ',self.server_command.decode() )
                self.server_command = json.loads(self.server_command .decode())
                if hasattr(self,self.server_command['command_key']):
                    getattr(self,self.server_command['command_key'])()
            except ConnectionResetError as e:
                print('\033[31;1mConnection[%s] failed...\033[0m' % (
                str(self.client_address[0])))
                #print('\033[31;1mConnection[%s:%s] failed...\033[0m' % (str(self.client_address[0]), str(self.client_address[1])))
                break
            except Exception as e:
                print('\033[31;1mConnection[%s] failed...\033[0m' % (
                    str(self.client_address[0])))
                break

    def register(self):
        """
                function: register new account
                :return:
                """
        account_path = '%s/%s' % (setting.DATABASE['db'], self.server_command['command_param'])
        if not os.path.isfile(account_path):                           # judge this account if exist
            self.request.send('False'.encode('utf-8'))
            password = self.request.recv(1024).decode()                # receive register password
            table_name = '%s.%s' % (setting.DATABASE['db'], self.server_command['command_param'])
            account_home = '%s/%s' % (setting.DATABASE['home'], self.server_command['command_param'])
            os.makedirs(account_home)                                  # make dir for new user
            user_data = {
                'account': self.server_command['command_param'],      # 用户账号
                'password': password,                                 # 密码
                'space_size_total': 102400000,                        # 默认分配空间大小
                'space_size_used': 0,                                 # 已经使用空间
                'home_dir': account_home                              # 家目录
            }
            sql.Sql('insert into %s value %s' % (table_name, user_data))  # save account info
            self.request.send("create account successful".encode('utf-8'))
        else:
            self.request.send('True'.encode('utf-8'))

    def login(self, *args):
        """
        function: login ftp server
        :return:
        """
        self.table_name = '%s.%s' % (setting.DATABASE['db'], self.server_command['command_param'])
        self.read_execute = sql.Sql('select * from %s' % self.table_name)
        if self.read_execute.status:
            # print('user info: ', self.read_execute.status[0], type(self.read_execute.status[0]))
            self.request.send(self.read_execute.status[0].encode('utf-8'))
        else:
            self.request.send('None'.encode('utf-8'))

    def ls(self):
        """
        列出当前目录或所查找目录的下一级内容
        :return:
        """

        dir_dic = {}
        if os.path.isdir(self.server_command['command_param']):
            chitem = os.listdir(self.server_command['command_param'])
            for path in chitem:
                chpath = '%s/%s' % (self.server_command['command_param'], path)
                if os.path.isdir(chpath):
                    dir_dic[path] = 'dir'
                else:
                    dir_dic[path] = 'file'
        self.request.send(json.dumps(dir_dic).encode('utf-8'))

    def upload(self):
        """
        上传文件
        :return:
        """

        filepath = self.server_command['command_filelocation']            # 文件上传路径
        filename = os.path.basename(self.server_command['command_param']) # 文件名
        file_size = self.server_command['command_filesize']               # 文件大小
        file_info = '%s/%s' %(filepath, filename)                         # 带路径的文件名
        self.read_execute.status[0]= eval(self.read_execute.status[0])    # 字符串转换为字典
        if file_size > int(self.read_execute.status[0]['space_size_total'])- int(self.read_execute.status[0]['space_size_used']):    # check space大小
            self.request.send('false'.encode('utf-8'))                    # 判断磁盘空间不足
        else:
            self.request.send('true'.encode('utf-8'))                     # 磁盘空间足够大
        check = self.request.recv(1024).decode()
        if check == 'false':                                        # 可用磁盘不足,退出
            return 1
        if os.path.isfile(file_info):                               # 文件存在,则查看文件大小
            file_offset = os.stat(file_info).st_size
        else:
            file_offset = 0
        self.request.send(str(file_offset).encode('utf-8'))         # 发送server端文件大小
        f = open(file_info, "ab+")                                  # 打开文件
        received_size = 0
        m_check = hashlib.md5()                                     # used hashlib.md5 to encryption
        while received_size < file_size - file_offset:              # 判断接收到的文件大小
            data = self.request.recv(1024)                          # 接受最大为1024大小文件
            f.write(data)                                           # 写入文件
            received_size += len(data)                              # 累加收到文件大小
            m_check.update(data)                                    # 更新md5校验
            self.request.send(m_check.hexdigest().encode())         # 发送校验码,以及防粘连
        else:
            self.read_execute.status[0]['space_size_used'] +=  file_size  # 上传成功修改使用空间大小
            sql.Sql("update %s from %s" % (self.read_execute.status[0], self.table_name))
        f.close()

    def info(self):
        """
        返回个人信息
        :return:
        """

        self.table_name = '%s.%s' % (setting.DATABASE['db'],self.server_command['command_param'] )  # 解析用户数据库名
        self.read_execute = sql.Sql('select * from %s' % self.table_name)   # 获取用户信息
        if self.read_execute.status:                                        # 用户存在
            self.request.send(self.read_execute.status[0].encode('utf-8'))
        else:                                                               # 用户不存在
            self.request.send(json.dumps('None').encode('utf-8'))

    def download(self):
        """"""
        if os.path.isfile(self.server_command['command_param']):               # 文件是否在本地存在
            file_size = os.stat(self.server_command['command_param']).st_size  # 文件大小
            self.request.send(str(file_size).encode('utf-8'))                  # 发送文件大小
        else:
            file_size = 0                                                  # 如果文件不存在,在大小为0
            self.request.send(str(file_size).encode('utf-8'))
            return 1
        client_file_size = int(self.request.recv(1024).decode())            # client端目前文件大小
        if client_file_size == file_size:
            return 1
        m_filecheck = hashlib.md5()
        file_obj = open(self.server_command['command_param'], 'rb')         # 打开文件
        file_obj.seek(client_file_size)                                     # 文件位置跳转,用于断点续传
        send_size = 1024                                                    # 每次发送文件大小
        while True:
            file_data = file_obj.read(send_size)                            # 读取文件
            if file_data:                                                   # 判断文件是否读取完成
                self.request.send(file_data)                                # 发送读取文件内容
                file_check = self.request.recv(1024)                        # 接受server确认信息
                m_filecheck.update(file_data)                               # 文件校验
            else:
                self.request.send(m_filecheck.hexdigest().encode())
                file_obj.close()                                            # 关闭文件
                break

    def cd(self):
        """
        切换目录
        :return:
        """
        if os.path.isdir(self.server_command['command_param']):
            cd_dir = self.server_command['command_param']
        else:
            cd_dir = 'None'
        self.request.send(json.dumps(cd_dir).encode('utf-8'))

    def mkdir(self):
        """
        新建文件夹
        :return:
        """
        if os.path.isdir(self.server_command['command_param']):  # 判断新建目录是否存在
            check_data = 'File[%s] is already exist ' %self.server_command['command_param']
        else:
            os.mkdir(self.server_command['command_param'])       # 不存在则新建目录
            check_data = 'Create successful'
        self.request.send(check_data.encode('utf-8'))

    def change_space(self):
        """
        配置用户磁盘配额
        :return:
        """
        new_info = self.server_command['command_param'].split()          # 参数解析
        self.table_name = '%s.%s' % (setting.DATABASE['db'], new_info[0])   # 解析用户数据库名
        print('tablename:',self.table_name)
        self.read_execute = sql.Sql('select * from %s' % self.table_name)  # 获取用户信息
        if self.read_execute.status:  # 用户存在
            account_data = eval(self.read_execute.status[0])
            print('info::',self.read_execute.status[0])
            print('info type:', type(account_data))
            account_data['space_size_total'] = str(new_info[1])               # 修改磁盘大小
            sql.Sql("update %s from %s" % (account_data, self.table_name))
            self.request.send('chang successful'.encode('utf-8'))
        else:  # 用户不存在
            self.request.send('No such user, chang failed'.encode('utf-8'))




if __name__ == "__main__":
    pass
else:
    HOST, PORT = "localhost", 9000
    # Create the server, binding to localhost on port 9999
    server = socketserver.ThreadingTCPServer((HOST, PORT), MyTCPHandler)
    server.serve_forever()
FTP server

 

 

import json
import hashlib
from conf import setting
from core import auth
import os
import re
import sys
import socket

def admin_required(func):
    """
    验证用户是否登录,装饰器
    :param func:
    :return:
    """
    def wrapper(self,*args,**kwargs):
        if self.user_data['account_id'] == 'admin':   # 判断用户是否为admin
            return func(self,*args, **kwargs)
        else:
            print("No permission... ")
            return False
    return wrapper


class FtpClient(object):

    def __init__(self):
        self.client = socket.socket()
        self.command = {}
        self.user_data = {}
        self.user_data = {                   # user data
            'account_id': None,
            'access_authenticated': None,    # judge user if authenticated
            'dir': {},                       # user operation dir info
            'home': None,                    # user home dir
        }
        self.connect(setting.SERVER_IP, setting.SERVER_PORT)

    def login(self):
        self.init_data = auth.login(self)  # call login authenticated function
        # self.init_data = {'password': 'e99a18c428cb38d5f260853678922e03', 'space_size_used': 0, 'home_dir': 'C:\\D\\personal_data\\workspace\\FTP\\server/home/greath', 'account': 'greath', 'space_size_total': 102400000}
        # print('receive: ',self.init_data )
        if self.init_data:
            self.user_data['dir']['current_dir'] = self.init_data['home_dir']  # 初始化当前目录
            #self.user_data['dir']['cddir'] = self.init_data['home_dir']        # 查找目录
            #self.user_data['dir']['chdir'] = {}                                # 打印查找目录
            self.user_data['home'] = self.init_data['home_dir']  # 登入成功后记录用户家目录
            self.command['command_key'] = 'ls'
            self.login_interview()  # 登入成功后显示界面
        else:
            print('\033[31;0m login failed\033[0m ')

    def login_interview(self):
        """
        登入成功后,显示界面
        :return:
        """
        while True:
            try:
                if hasattr(self, self.command['command_key']):
                    getattr(self,self.command['command_key'])()
                    self.command['command_key'] = ' '
                    self.command['command_param'] = ' '
                else:
                    print('\033[41;0mEnter[%s] error...\033[0m' % self.command)

            except Exception as e:
                print('\033[0;31mError Found[%s]' %e)
                exit()
            command_pro = re.search(r'/home/.*', self.user_data['dir']['current_dir'].replace('\\', '/'))
            if not command_pro:
                pass                  # exit("can't find file[%s]" % self.dir['current_dir'])
            self.command_pro = command_pro.group()        # 输入提示
            while True:
                command = input('[%s:%s]# ' % (self.user_data['account_id'], self.command_pro)).strip()
                if len(command) > 0:
                    break
            self.command_paras(command)

    @admin_required
    def change_space(self):
        """
        改变用户磁盘配额
        :return:
        """
        self.client.send(json.dumps(self.command).encode('utf-8'))
        data = self.client.recv(1024).decode()
        print(data)


    def command_paras(self, command):
        """
        输入命令解析
        :param command:
        :return:
        """
        command_l = command.split(maxsplit=1)     # 提取命令关键字和参数
        # print(command_l)
        self.command['command_key'] = command_l[0]  # 关键字
        if len(command_l) > 1:
            self.command['command_param'] = command_l[1]   #参数
        else:
            self.command['command_param'] = None

    def mkdir(self):
        """
        新建文件夹
        :return:
        """
        self.command['command_param'] = '%s/%s' %(self.user_data['dir']['current_dir'],self.command['command_param'])  # 相对路径创建文件夹
        self.client.send(json.dumps(self.command).encode('utf-8'))
        data = self.client.recv(1024).decode()
        print(data)

    def ls(self):
        """"""
        if not self.command.get('command_param') or self.command['command_param'] == None:
            self.command['command_param'] = self.user_data['dir']['current_dir']
        self.client.send(json.dumps(self.command).encode('utf-8'))  # 发送ls命令
        ls_data = json.loads(self.client.recv(10240).decode())      # 接受查询数据
        print("\033[36;1m|-- %s/\033[0m" % self.user_data['dir']['current_dir'])
        for path in ls_data:
            if ls_data[path] == 'dir':
                print("\033[32;1m    |-- %s\033[0m" % path)
            elif ls_data[path] == 'file':
                print('\033[30;1m    |-- %s\033[0m' % path)
        print()

    def cd(self):
        """
        切换目录
        :return:
        """
        if self.command['command_param'] == '.':                                  # 返回当前目录
            self.command['command_param'] = self.user_data['dir']['current_dir']
        elif self.command['command_param'] == '..':                               # 返回上一层目录
            if self.command_pro == '/home/%s' % self.user_data['account_id']:     # 用户只能访问自己的家目录
                print('No access permission to /home')
                return 1
            self.command['command_param'] = os.path.dirname(self.user_data['dir']['current_dir'].strip('/'))
        elif self.command['command_param'] not in self.user_data['dir']['current_dir'].replace('\\', '/'):
            self.command['command_param']='%s/%s' %(self.user_data['dir']['current_dir'],self.command['command_param'])
        self.client.send(json.dumps(self.command).encode('utf-8'))           # 发送cd命令
        cd_data = json.loads(self.client.recv(10240).decode())               # 接受查询数据
        if cd_data != 'None':
            self.user_data['dir']['current_dir'] = cd_data
            self.command['command_param'] = self.user_data['dir']['current_dir']
            self.command['command_key'] = 'ls'
            self.ls()
        else:
            print('\033[31;1mNo such dir[%s]\033[0m' % (self.command['command_param']))

    def upload(self):
        """
        上传文件
        :return:
        """
        if os.path.isfile(self.command['command_param']):                # 文件是否在本地存在
            self.command['command_filesize'] = int(os.stat(self.command['command_param']).st_size)  # 文件大小
#            if not self.command.get('command_filelocation'):             # 如果没有输入保存路径
            self.command['command_filelocation'] = self.user_data['dir']['current_dir']   # 文件保存位置,默认为当前路径
            self.client.send(json.dumps(self.command).encode('utf-8'))   # 发送upload命令
            check = self.client.recv(1024).decode()
            if check == 'false':
                print('space is not enough, upload failed')
                self.client.send('false'.encode('utf-8'))
                return 1
            else:
                self.client.send('true'.encode('utf-8'))
            file_offset = int(self.client.recv(1024).decode())                      # 接受server端的确认信息
            if file_offset < self.command['command_filesize'] :          # 判断server端文件大小
                self.file_upload(file_offset)                            # 开始
            else:
                print('file is already exist.')
        else:                                                            # 文件在本地不存在
            print('\033[31;0mFile [%s] is not exist...\033[0m' % self.command['command_param'])
            return 1
        del self.command['command_filesize']                             # 清除命令无关key
        del self.command['command_filelocation']

    def file_upload(self, total_send_size):
        """
        上传文件
        :param init_offset:
        :return:
        """
        total_send_size = int(total_send_size)
        file_obj = open(self.command['command_param'], 'rb')     # 打开文件
        file_obj.seek(total_send_size)                           # 文件位置跳转,用于断点续传
        send_size = 1024                                         # 发送文件大小
        m_check = hashlib.md5()                                  # 文件校验
        file_check_server = 0                                    # 初始化server端文件校验码
        while total_send_size < int(self.command['command_filesize']):# 可以修改size == 0的判断
            file_data = file_obj.read(send_size)                 # 读取文件
            self.client.send(file_data)                          # 发送读取文件内容
            file_check_server = self.client.recv(1024).decode()  # 接受server确认信息
            m_check.update(file_data)                            # 文件校验
            total_send_size += send_size
            # 传输过程进度条
            percent = '{0:.0%}'.format(total_send_size/ self.command['command_filesize'])  #打印进度条
            sys.stdout.write(
                 '\r [%-100s] %s' % ('#' * int(total_send_size / self.command['command_filesize'] * 100), percent))
            sys.stdout.flush()
            if total_send_size > self.command['command_filesize'] -1024:   # 判断最后一次传输大小
                send_size = self.command['command_filesize'] - total_send_size
        else:
            print()
            if m_check.hexdigest() == file_check_server:          # 相同则上传成功
                print('upload file completed!!!')
            else:
                print('upload failed...')
            file_obj.close()                                      # 关闭文件

    def info(self):
        """
        打印个人信息
        :return:
        """
        if not self.command['command_param']:
            self.command['command_param'] = self.user_data['account_id']
        self.client.send(json.dumps(self.command).encode('utf-8'))
        data = self.client.recv(10240).decode()
        if data != 'None':
            data = eval(data)
            info = """
            -----------------info of {account}------------------
            Account:{account}
            Home:{home}
            Space_size:{space_size}
            Space_used:{space_used}        
            """.format(
                        account = data['account'],
                        home = data['home_dir'],
                        space_size = data['space_size_total'],
                        space_used = data['space_size_used']
            )
            print(info)
        else:
            print("get info failed")

    def download(self):
        """"""
        filename = self.command['command_param']
        self.command['command_param'] = '%s/%s' %(self.user_data['dir']['current_dir'], self.command['command_param'])
        self.client.send(json.dumps(self.command).encode('utf-8'))      # 发送download命令
        total_file_size = int(self.client.recv(1024).decode())          # 接收文件大小
        if total_file_size == 0:                                        # 接收来自server的确认文件是否存在
            print("file is not exist in server...")
            return 1
        received_size = 0                                                   # 初始化client端文件大小
        if os.path.isfile(filename):                                    # 判断文件是否存在
            file_size = os.stat(filename).st_size
            if total_file_size == received_size :
                print("this file is already exist")                     # 文件存在并且大小一致,返回
                self.client.send(str(total_file_size).encode('utf-8'))
                return 0
        self.client.send(str(received_size).encode('utf-8'))
        m_filecheck = hashlib.md5()
        size = 1024
        with open(filename, 'ab+') as f:
            while received_size < total_file_size:
                if received_size > total_file_size -1024:
                    size = total_file_size - received_size                 # 设置每次接受的大小
                data = self.client.recv(size)
                f.write(data)                                               # 写入数据
                # 传输过程进度条
                percent = '{0:.0%}'.format(received_size / total_file_size)  # 打印进度条
                sys.stdout.write(
                    '\r [%-100s] %s' % ('#' * int(received_size / total_file_size * 100), percent))
                sys.stdout.flush()
                self.client.send('receive_check'.encode('utf-8'))           # 防粘连
                received_size += size
                m_filecheck.update(data)
            else:
                print()                                             # 输出换行
                file_check = self.client.recv(1024).decode()        # 接收server校验码
                if file_check == m_filecheck.hexdigest():
                    print('download successful...')

    def exit(self):
        """
        退出
        :return:
        """
        exit()

    def help(self):
        msg = '''
        ls                                   # list file or folder info
        cd dir_param                         # [..] Return to the previous level directory.
        download filename                    # offer filename in relative paths 
        upload filename                      # offer filename in absolute paths 
        mkdir  folder_name                   # Creat new folder
        info   account_id                    # return user info
        chang_space account_id new_size      # chang user space size configure
        exit                                 # exit system
        '''
        print(msg)

    def register(self):
        """
        function: register new account
        :return:
        """
        self.command['command_key'] = 'register'
        account = input('please input user name: ').strip()  # get register account
        self.command['command_param'] = account
        self.client.send(json.dumps(self.command).encode('utf-8'))  # send register command to server
        account_exist = self.client.recv(1024).decode()
        if account_exist == 'False':                     # judge this account if exist
            password = self.password_get()
            self.client.send(password.encode('utf-8'))  # send user password to server
            self.client.recv(1024)  # 防粘连
            print("\033[33;0mAccount: [%s] register successful.\033[0m" % account)
        else:
            print("\033[31;0mAccount: [%s] is already exist.\033[0m" % account)
        input('Enter any key to back')

    def password_get(self):
        """
        获取密码,并且hashlib加密
        :return:
        """
        while True:
            password = input('input password: ').strip()
            if len(password) == 0:
                continue
            if self.command['command_key'] == 'register':    # 注册需要确认密码
                while True:
                    password_confirm = input('confirm password: ').strip()
                    if password == password_confirm:
                        break
                    else:
                        print('\033[31;0mPassword is inconsistent...\033[0m')
            m = hashlib.md5()                                # used hashlib.md5 to encryption
            m.update(password.encode('utf-8'))
            password = m.hexdigest()
            return password

    @staticmethod
    def exit_system():
        """
        function: exit ftp system
        :return:
        """
        exit()

    @staticmethod
    def user_choice(menu, menu_dic, *args, **kwargs):
        """
        function: provide one interface for user to choice which channel.
        :param menu: menu
        :param menu_dic: function
        :param args:
        :param kwargs:
        :return:
        """
        exit_flag = False
        while not exit_flag:
            print(menu)  # print current menu
            user_option = input(">>>: ").strip()
            if user_option in menu_dic:
                exit_flag = menu_dic[user_option]()  # call menu function which user choice
            else:
                print("\033[31;1mOption does not exist!\033[0m")
                input("Enter any key to back")

    def system_interactive(self):
        """
        function: provide interactive between user and ftp system.
        :return:
        """
        menu = u'''
            ************** FTP **************
            \033[32;1m
            1.  登入
            2.  注册
            3.  退出
            \033[0m'''
        menu_dict = {
            '1': self.login,
            '2': self.register,
            '3': self.exit_system
        }
        self.user_choice(menu, menu_dict)

    def connect(self, ip, port):
        '''
        连接server
        :param ip:
        :param port:
        :return:
        '''
        self.client.connect((ip, port))

def main():
    ftp = FtpClient()
    ftp.system_interactive()
FTP client

 


 

posted @ 2018-01-12 15:19  徘徊的游鱼  阅读(292)  评论(0编辑  收藏  举报