返回顶部
扩大
缩小

6. 12306火车购票系统 (车票和列车数据的获取)

day 2019-7-9

实现的功能

  • 登录
  • 注册
  • 车票查询接口
  • 列车查询接口
  • 权限、频率校验

路由配置 urls.py

# URL routing for the 12306 ticketing demo project.
import xadmin
from xadmin.plugins import xversion
# Run xadmin's admin auto-discovery and the xversion plugin's model registration
# (presumably version-history tracking — confirm) before xadmin.site.urls is used below.
xadmin.autodiscover()
xversion.register_models()
from django.conf.urls import url
# from django.contrib import admin
from app01 import views
# Patterns are tried top to bottom. Each regex is anchored only at the start ('^'),
# so e.g. 'login/extra' still reaches Login — NOTE(review): append '$' if that is unintended.
urlpatterns = [
    url(r'^xadmin/', xadmin.site.urls),              # xadmin admin site
    url(r'^register/', views.Register.as_view()),    # registration
    url(r'^login/', views.Login.as_view()),          # login
    url(r'^index/', views.Index.as_view()),          # home page (token-protected)
    url(r'^otn/', views.Otn.as_view()),              # ticket search (new)
    url(r'^train_size/', views.Train_all.as_view()), # full stop list of one train (new)
]

视图 views.py

import time

from django.contrib import auth
from django.shortcuts import HttpResponse
from rest_framework.response import Response
# Create your views here.
from rest_framework.views import APIView

from app01 import models
from app01.myfile.Otn import Otn_query
from app01.myfile.my_configuration import RegSerializer
from app01.myfile.my_configuration import TokenAuth
from app01.myfile.tokens import Token

# 注册
class Register(APIView):
    """Create a new ``UserInfo`` account from ``username`` / ``password`` / ``email``."""

    # Registration must be reachable without a token, so no authenticators run.
    authentication_classes = []

    def post(self, request):
        """
        Validate the submitted fields and create the user.

        :param request: DRF request whose body carries ``username``, ``password``, ``email``
        :return: ``{'code': '200', 'msg': '', 'data': {...}}`` on success, where ``data``
            echoes the saved fields without the password; otherwise the serializer's
            field errors.
        """
        func_dic = {'code': '200', 'msg': '', "data": ''}
        res = RegSerializer(data=request.data)

        if res.is_valid():
            # save() hands the *validated* data to RegSerializer.create(); the original
            # passed ``res.data`` (the output representation) instead.
            res.save()
            data = dict(res.data)
            # Never send the password back to the client.
            data.pop('password', None)
            func_dic['data'] = data
            return Response(func_dic)
        return Response(res.errors)

# 登录
class Login(APIView):
    """Check ``username`` / ``password`` and issue a signed login token."""

    def post(self, request):
        """
        Authenticate the user and return a token.

        :param request: DRF request whose body carries ``username`` and ``password``
        :return: ``code`` '200' with ``msg`` and ``token`` on success; '401' when the user
            does not exist; '402' when the password is wrong.
        """
        func_dic = {'code': '200', 'msg': '', 'data': ''}
        username = request.data.get('username')
        password = request.data.get('password')

        if not models.UserInfo.objects.filter(username=username).exists():
            func_dic['code'] = '401'
            func_dic['msg'] = '用户不存在!'
            return Response(func_dic)

        if auth.authenticate(username=username, password=password) is None:
            func_dic['code'] = '402'
            func_dic['msg'] = '密码错误!'
            return Response(func_dic)

        func_dic['msg'] = '登录成功!'
        # Send the token that was actually generated (the original sent the literal 'token!').
        func_dic['token'] = Token().create_token(username)
        return Response(func_dic)


# 主页面
class Index(APIView):
    """Home page endpoint; requests must first pass ``TokenAuth``."""

    authentication_classes = [TokenAuth]

    def get(self, request):
        """Reply with the fixed plain-text body ``'False'``."""
        body = 'False'
        return HttpResponse(body)


# 查询票
class Otn(APIView):
    """Ticket search between two stations: ``GET ?from_station=<code>&to_station=<code>``."""

    def get(self, request):
        """
        Return matching trains plus a per-train summary.

        :param request: DRF request with ``from_station`` / ``to_station`` query parameters
        :return: ``data`` = ``(train_numbers, station_map)``, ``train_data`` = summary
            strings, ``len`` = number of summaries; ``code`` '400' if a station is unknown.
        """
        func_dic = {'code': '200', 'msg': '', "data": ''}
        from_station = request.GET.get('from_station')
        to_station = request.GET.get('to_station')

        otn = Otn_query()
        # query() returns False for an unknown station. Check that BEFORE indexing into
        # the result (the original evaluated data[1] first and crashed with TypeError).
        data = otn.query(from_station, to_station)
        if not data:
            func_dic['code'] = '400'
            func_dic['msg'] = '车站不存在!'
            return Response(func_dic)

        # Times, stations, seat counts and prices for every matching train.
        train_data = otn.train(data, data[1])
        func_dic['data'] = data
        func_dic['train_data'] = train_data
        func_dic['len'] = len(train_data)
        return Response(func_dic)


# 获取所有列车
class Train_all(APIView):
    """Return the full stop list of one train, looked up by ``train_size`` (train number)."""

    def post(self, request):
        """Respond with ``Otn_query().train_all(...)``, or ``{}`` when no train number is sent."""
        train_size = request.data.get('train_size')
        if not train_size:
            return Response({})
        return Response(Otn_query().train_all(train_size))

登录 token 的生成与校验 tokens.py

优点:token 无需在服务端存储,直接使用 Django signing 校验签名,不占用服务器内存资源;其安全性取决于签名密钥(KEY)是否保密。

import hashlib
import time

from django.core import signing


class Token:
    """
    Stateless login token of the form ``header.payload.digest``.

    ``header`` and ``payload`` are each produced by ``signing.dumps`` (signed with
    ``KEY`` + ``SALT``) and then base64-encoded again. ``digest`` is an MD5 hex of the
    first two parts. Verification only decodes ``payload``; its signing signature is
    what makes the token tamper-evident, since the MD5 part involves no secret.
    """

    def __init__(self):
        self.HEADER = {'typ': 'get_token', 'alg': 'default'}
        # NOTE(review): hard-coded signing key in source code. Anyone who can read it can
        # forge tokens; consider settings.SECRET_KEY or an environment variable.
        self.KEY = 'LI QIANG'
        self.SALT = 'www.liqianglog.top'  # salt namespaces the signatures
        self.TIME_OUT = 30*60  # tokens expire 30 minutes after issue

    def encrypt(self, obj):
        """
        Sign ``obj`` with ``signing.dumps`` and base64-encode the signed string.

        :param obj: JSON-serialisable object
        :return: encoded string (one token segment)
        """
        value = signing.dumps(obj, key=self.KEY, salt=self.SALT)
        value = signing.b64_encode(value.encode()).decode()
        return value

    def decrypt(self, src):
        """
        Decode and verify the payload (second '.'-separated segment) of a token.

        :param src: full token string
        :return: the payload object passed to ``encrypt``
        :raises IndexError: ``src`` has no '.' separator
        :raises ValueError: the segment is not valid base64 / UTF-8
        :raises signing.BadSignature: the signature does not match ``KEY`` / ``SALT``
        """
        src = str(src).split('.')[1]

        src = signing.b64_decode(src.encode()).decode()
        src = signing.loads(src, key=self.KEY, salt=self.SALT)
        return src

    def create_token(self, username):
        """
        Build a token for ``username`` stamped with the current time.

        :param username: user the token is issued to
        :return: ``"header.payload.digest"`` string
        """
        # 1. Encoded header
        header = self.encrypt(self.HEADER)
        # 2. Payload: who the token belongs to and when it was issued
        payload = {'username': username, 'iat': time.time()}
        payload = self.encrypt(payload)
        # 3. Digest over header.payload (unkeyed MD5; not checked by check_token)
        md5 = hashlib.md5()
        md5.update(("{}.{}".format(header, payload)).encode())
        signture = md5.hexdigest()
        token = "{}.{}.{}".format(header, payload, signture)

        return token

    def check_token(self, token, get_username):
        """
        Verify a token issued by ``create_token``.

        :param token: token string supplied by the client (may be ``None``)
        :param get_username: username the client claims to be
        :return: ``True`` if the payload signature is valid, the token is at most
            ``TIME_OUT`` seconds old and it was issued to ``get_username``; otherwise
            ``False``. No refreshed token is produced.
        """
        try:
            payload = self.decrypt(token)
        except (IndexError, ValueError, signing.BadSignature):
            # Malformed token (no '.', bad base64 / UTF-8) or forged/altered payload.
            return False
        if not isinstance(payload, dict):
            return False
        old_time = payload.get('iat')
        # A missing/non-numeric issue time used to raise TypeError on the subtraction.
        if not isinstance(old_time, (int, float)):
            return False
        if time.time() - old_time > self.TIME_OUT:
            return False
        username = payload.get('username')
        return username is not None and username == get_username

车票和列车数据获取 Otn.py

from app01 import models
import time
class Otn_query:
    """Ticket and train queries over the Station, Station2Train, Train and Seat models."""

    def query(self, from_station, to_station):
        """
        Find trains running from the departure city towards the arrival city.

        Both stations are widened to every Station sharing their ``city``.

        :param from_station: ``Station.english`` code of the departure station
        :param to_station: ``Station.english`` code of the arrival station
        :return: ``False`` if either code is unknown; otherwise
            ``(train_numbers, station_map)``. ``station_map`` is
            ``{'from_station': [codes...], 'to_station': [codes...]}``, listing the codes
            of all stations in the departure / arrival city.
        """
        # Both stations must exist. The original test `not from_res and to_res` let an
        # unknown arrival station through, which then crashed on `None.city`.
        from_obj = models.Station.objects.filter(english=from_station).first()
        to_obj = models.Station.objects.filter(english=to_station).first()
        if from_obj is None or to_obj is None:
            return False

        station_map = {'from_station': [], 'to_station': []}
        from_ids = self._city_station_ids(from_obj, station_map['from_station'])
        to_ids = self._city_station_ids(to_obj, station_map['to_station'])

        from_train_dic = self._trains_through(from_ids)  # trains stopping in departure city
        to_train_dic = self._trains_through(to_ids)      # trains stopping in arrival city

        # A train qualifies if it stops in both cities and reaches the arrival city
        # later in its stop sequence (positive difference = correct direction).
        ok_list = []
        for train_size, station_next in from_train_dic.items():
            if train_size in to_train_dic:
                if int(to_train_dic[train_size]) - int(station_next) > 0:
                    ok_list.append(train_size)
        return ok_list, station_map

    def _city_station_ids(self, station_obj, codes):
        """Return ids of all stations in ``station_obj``'s city, appending their codes to ``codes``."""
        ids = []
        rows = models.Station.objects.filter(city=station_obj.city).all().values('id', 'english')
        for row in rows:
            ids.append(row.get('id'))
            codes.append(row.get('english'))
        return ids

    def _trains_through(self, station_ids):
        """Map train number -> ``station_next`` (stop sequence) for trains stopping at any id in ``station_ids``."""
        train_dic = {}
        for station in station_ids:
            rows = models.Station2Train.objects.filter(station=station).values_list(
                'train__train_size', 'station_next')
            for train_size, station_next in rows:
                # Later stations in the list overwrite earlier ones for the same train.
                train_dic[train_size] = station_next
        return train_dic

    def train(self, train_size_list, station_dic):
        """
        Build one pipe-delimited summary string per train found by ``query``.

        Each string is
        ``train|depart_time|arrive_time|from_name|to_name|duration|<7 seat counts>|<7 prices>|``.

        :param train_size_list: the tuple returned by ``query``; only element 0 (the list
            of train numbers) is read
        :param station_dic: the station-code map (element 1 of ``query``'s result)
        :return: list of summary strings
        """
        func_data = []
        for train_size in train_size_list[0]:
            # Every stop of this train: code, name, arrival time, departure time, price.
            data = models.Station2Train.objects.filter(train__train_size=train_size).values_list(
                'station__english', 'station__station_name', 'arrive_time', 'depart_time', 'price')
            depart_time = arrive_time = start_stand = terminus = ''
            section = []  # 1-based positions of the boarding / alighting stops
            price = 0
            for i, (code, name, stop_arrive, stop_depart, stop_price) in enumerate(data, 1):
                if code in station_dic['from_station']:
                    start_stand = name          # boarding station
                    depart_time = stop_depart   # departure time
                    price = stop_price          # price recorded at the boarding stop
                    section.append(i)
                if code in station_dic['to_station']:
                    terminus = name             # alighting station
                    arrive_time = stop_arrive   # arrival time
                    # Fare = alighting-stop price minus boarding-stop price
                    # (presumably prices are cumulative from the origin — confirm).
                    price = float(stop_price) - float(price)
                    section.append(i)

            take = self.times(depart_time, arrive_time)  # travel duration
            seat_number = self.seat_type(train_size, section)
            prices = (
                price * 3,  # business class
                price * 2,  # first class
                price * 1,  # second class
                price * 5,  # deluxe soft sleeper (original comment said x4 — confirm intent)
                price * 4,  # deluxe hard sleeper
                price * 1,  # hard seat
                price * 1,  # no seat (standing)
            )
            info = "{:g}|{:g}|{:g}|{:g}|{:g}|{:g}|{:g}|".format(*prices)
            res = "{}|{}|{}|{}|{}|{}|{}|{}".format(
                train_size, depart_time, arrive_time, start_stand, terminus, take, seat_number, info)
            func_data.append(res)
        return func_data

    def train_all(self, train_size):
        """
        List every stop of one train.

        :param train_size: train number
        :return: ``{train_size: [[origin_name, terminus_name], stops]}``, where each stop
            string is ``sequence|station|arrive|depart|dwell|is_state``; ``{}`` when the
            train does not exist.
        """
        train_obj = models.Train.objects.filter(train_size=train_size).values_list(
            'start_stand__station_name', 'terminus__station_name').first()
        if train_obj is None:
            # Unknown train: the original crashed subscripting None.
            return {}

        data = models.Station2Train.objects.filter(train__train_size=train_size).values_list(
            'station_next', 'station__station_name', 'arrive_time', 'depart_time', 'is_state')
        stops = []
        for station_next, station_name, arrive_time, depart_time, is_state in data:
            take = self.times(arrive_time, depart_time)  # dwell time at this stop
            stops.append("{}|{}|{}|{}|{}|{}".format(
                station_next, station_name, arrive_time, depart_time, take, is_state))
        return {train_size: [[train_obj[0], train_obj[1]], stops]}

    def times(self, str_time_1, str_time_2):
        """
        Format the span between two clock times as '<h>小时<m>分钟' or '<m>分钟'.

        If the second time is earlier than the first, it is taken to be on the next day.
        Values are converted with ``str()`` and only the first two ':'-separated fields
        (hours, minutes) are used.

        :param str_time_1: start time, e.g. '10:30'
        :param str_time_2: end time
        :return: formatted duration, or '--' when a value is unparseable or the span is zero
        """
        try:
            h1, m1 = str(str_time_1).split(':')[:2]
            h2, m2 = str(str_time_2).split(':')[:2]
            total = (int(h2) * 60 + int(m2)) - (int(h1) * 60 + int(m1))
        except (IndexError, ValueError):
            # e.g. '----', None or other non 'HH:MM' placeholders.
            return '--'
        if total == 0:
            return '--'
        if total < 0:
            # Crosses midnight (overnight run, or a dwell spanning 00:00). The original
            # returned negative/garbage minutes here.
            total += 24 * 60
        hours, minutes = divmod(total, 60)
        if hours == 0:
            return "{}分钟".format(minutes)
        return "{}小时{}分钟".format(hours, minutes)

    def seat_type(self, train_size, section):
        """
        Count available seats of each of the 7 seat types over a journey section.

        :param train_size: train number
        :param section: 1-based stop positions collected by ``train``; the slice
            ``section[0]:section[1]`` of each seat's ``is_sell`` string is inspected
        :return: seven counts joined by '|', position = ``seat_type - 1``
        """
        counts = [0] * 7
        # Without both a boarding and an alighting position there is no section to check
        # (the original raised IndexError here whenever the train had seats).
        if len(section) < 2:
            return "|".join(str(c) for c in counts)
        start, end = section[0], section[1]
        rows = models.Seat.objects.filter(train__train_size=train_size).values_list('is_sell', 'seat_type')
        for is_sell, kind in rows:
            # A '0' anywhere in the section skips the seat (presumably "sold on that leg").
            if '0' in is_sell[start:end]:
                continue
            counts[kind - 1] += 1
        return "|".join(str(c) for c in counts)

全局配置:访问频率 settings.py

# DRF global settings: VisitThrottle is the default throttle, limited to 20 requests
# per minute under the '12306' rate scope.
REST_FRAMEWORK = dict(
    DEFAULT_THROTTLE_CLASSES=['app01.myfile.my_configuration.VisitThrottle'],
    DEFAULT_THROTTLE_RATES={'12306': '20/m'},
)

自定义配置文件 my_configuration.py

from rest_framework.serializers import ModelSerializer
from app01 import models
from rest_framework import exceptions
from app01.myfile.tokens import Token
# 注册序列化
class RegSerializer(ModelSerializer):
    """Validate registration input and create the user with ``create_user``."""

    class Meta():
        model = models.UserInfo
        fields = ('username','password','email')

    def validate_username(self, data):
        """Reject usernames containing the banned substring 'sb'."""
        if 'sb' in data:
            raise exceptions.ValidationError('内容包含敏感词汇!')
        return data

    def create(self, validated_data):
        """
        Create the user via the manager's ``create_user``.

        :param validated_data: dict with ``username``, ``password``, ``email``
        :return: the created ``UserInfo`` instance. DRF's ``save()`` stores this as
            ``serializer.instance``; the original returned the input dict instead.
        """
        return models.UserInfo.objects.create_user(**validated_data)

# 用户校验
from rest_framework.authentication import BaseAuthentication
class TokenAuth(BaseAuthentication):
    """
    Authenticate a request by the ``token`` + ``username`` pair it carries.

    Both values are read from the request body and, failing that, from the query string,
    so GET endpoints can authenticate too. The original read the body only, so GET
    requests always failed.
    """

    def authenticate(self, request):
        """
        Check the token with ``Token.check_token``.

        :return: ``(user, token)`` on success, per DRF's authenticator contract
        :raises exceptions.AuthenticationFailed: token invalid/expired or user not found
        """
        token = request.data.get('token') or request.query_params.get('token')
        username = request.data.get('username') or request.query_params.get('username')
        if not Token().check_token(token, username):
            raise exceptions.AuthenticationFailed('认证失败')
        user = models.UserInfo.objects.filter(username=username).first()
        if user is None:
            raise exceptions.AuthenticationFailed('认证失败')
        return user, token

    def authenticate_header(self, request):
        # Returns None (no WWW-Authenticate header). Per DRF's documented behaviour,
        # failed authentication is then answered with 403 rather than 401.
        pass


# 访问频率限制
from rest_framework.throttling import SimpleRateThrottle
class VisitThrottle(SimpleRateThrottle):
    """Per-client rate limit using the '12306' rate from DEFAULT_THROTTLE_RATES."""
    scope = '12306'

    def get_cache_key(self, request, view):
        """
        Build the cache key for this client's request history.

        :return: ``cache_format`` filled with the scope and ``get_ident(request)``
        """
        # Namespace the key with the scope (DRF's standard ``cache_format``) so the
        # counter cannot collide with other cache entries keyed by the bare client ident.
        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request),
        }

posted on 2019-07-09 14:14  代码创造一切R  阅读(1196)  评论(0)  编辑  收藏  举报

导航