1、A* 中的 g、h 函数使用两点之间五次多项式曲线的弧长作为代价,并根据需要加入少量修正项进行微调。

2、五次多项式的末状态设为 (l, l′=0, l″=0),因此中间节点两两相连时,曲线在各节点处都是水平的(横向斜率和曲率均为 0)。更好的做法是先得到各个轨迹点,再对这些轨迹点整体重新拟合;或者在搜索过程中就考虑中间节点的状态并非水平。

3、如果模型为低速模型,那这个轨迹应该是可以使用的。

# -*- coding: utf-8 -*-
"""
Created on Fri Dec 27 11:02:55 2019

@author: leizhen.liu
"""
from scipy.integrate import quad
import numpy as np
import matplotlib.pyplot as plt

class VehicleState:
    """Boundary state of a lattice node used by the quintic fit.

    Attributes:
        s: longitudinal coordinate.
        l: lateral coordinate.
        vpl: value trajectoryCost uses as the first derivative dl/ds.
        apl: value trajectoryCost uses as the second derivative d2l/ds2.
    """

    def __init__(self, s, l, vpl, apl):
        self.s, self.l = s, l
        self.vpl, self.apl = vpl, apl


class trajectoryCost:
    """Quintic lateral-offset curve between two lattice nodes and its arc length.

    The curve is fitted in a local frame whose origin is the start node:
    ``l(s) = p0 + p1*s + p2*s**2 + p3*s**3 + p4*s**4 + p5*s**5`` on ``[0, S]``,
    where ``S`` is the longitudinal distance (metres) to the end node. The
    coefficients are stored as a 6x1 matrix in ``self.matP`` and the arc length
    of the curve in ``self.cost``.
    """

    def __init__(self, startVehicleState, endVehicleState, totals, totall, detas, detal):
        """Build the local-frame boundary states and compute the cost.

        Args:
            startVehicleState, endVehicleState: objects exposing ``s`` and ``l``
                grid indices plus ``vpl`` (dl/ds) and ``apl`` (d2l/ds2).
            totals: number of grid cells along s; only used to build ``self.id``.
            totall: unused; kept for interface compatibility.
            detas, detal: grid spacing along s and l, used to convert indices
                to distances.
        """
        # Shift the start node to the origin so the polynomial is fitted locally.
        self.startVehicleState = VehicleState(0, 0, startVehicleState.vpl, startVehicleState.apl)
        self.endVehicleState = VehicleState((endVehicleState.s - startVehicleState.s) * detas,
                                            (endVehicleState.l - startVehicleState.l) * detal,
                                            endVehicleState.vpl,
                                            endVehicleState.apl)
        # World offset of the local origin; added back when plotting.
        self.sorg = startVehicleState.s * detas
        self.lorg = startVehicleState.l * detal
        self.id = (str(startVehicleState.s + startVehicleState.l * totals) + 'id'
                   + str(endVehicleState.s + endVehicleState.l * totals))
        self.matP = np.zeros((6, 1))
        self.arcLength()

    def calQuintic(self):
        """Solve for quintic coefficients matching both boundary states.

        The six conditions are l, dl/ds, d2l/ds2 at s = 0 and at s = S. The result
        is stored in ``self.matP`` as a 6x1 ``numpy.matrix``. It is all zeros when
        the system is (near) singular, i.e. when S is (almost) 0.
        """
        s = self.endVehicleState.s
        matS = np.array([[1, 0, 0, 0, 0, 0],
                         [0, 1, 0, 0, 0, 0],
                         [0, 0, 2, 0, 0, 0],
                         [1, s, s**2, s**3, s**4, s**5],
                         [0, 1, 2*s, 3*s**2, 4*s**3, 5*s**4],
                         [0, 0, 2, 6*s, 12*s**2, 20*s**3]], dtype=float)

        # det(matS) == 4 * s**9, which is negative for s < 0, so test its magnitude.
        if abs(np.linalg.det(matS)) < 0.001:
            self.matP = np.asmatrix(np.zeros((6, 1)))
            return

        matL = np.array([0, self.startVehicleState.vpl, self.startVehicleState.apl,
                         self.endVehicleState.l, self.endVehicleState.vpl, self.endVehicleState.apl],
                        dtype=float)
        # solve() is more accurate than forming the explicit inverse. np.asmatrix
        # keeps the matrix type (the np.mat alias was removed in NumPy 2.0).
        self.matP = np.asmatrix(np.linalg.solve(matS, matL).reshape(6, 1))

    def f(self, s):
        """Arc-length integrand sqrt(1 + l'(s)**2), returned as a Python float."""
        p = np.asarray(self.matP, dtype=float).ravel()
        dl = p[1] + 2*p[2]*s + 3*p[3]*s**2 + 4*p[4]*s**3 + 5*p[5]*s**4
        return float(np.sqrt(1.0 + dl * dl))

    def arcLength(self):
        """Fit the quintic and store its arc length over [start.s, end.s] in ``self.cost``."""
        self.calQuintic()
        # quad returns (value, abs_error_estimate); only the value is kept.
        self.cost, _err = quad(self.f, self.startVehicleState.s, self.endVehicleState.s)

    def show(self):
        """Plot the curve in world coordinates, sampled every 0.3 along s (end excluded)."""
        s = list()
        l = list()
        print('stop =', self.endVehicleState.s)
        p = np.asarray(self.matP, dtype=float).ravel()
        for i in np.arange(0, self.endVehicleState.s, 0.3):
            s.append(i + self.sorg)
            l.append(sum(p[k] * i**k for k in range(6)) + self.lorg)
        plt.plot(s, l, '-')
            
        
# -*- coding: utf-8 -*-
"""
Spyder Editor

This is a temporary script file.
"""

import numpy as  np
import matplotlib.pyplot as plt
import math
import queue
import time
import copy
import huchang

# Maximum number of s columns ahead of the current node expanded per A* step.
MaxsSearch = 5

detaWidth = 0.3        # lateral (l) grid spacing, metres
totalWidth = 6.0       # road width, metres
detaLength = 3.0       # longitudinal (s) grid spacing, metres
totalLength = 60.0     # road predict length, metres
vehicleHarfWidth = 0.8  # vehicle half width, metres

# Obstacles, one per row: [s, l, half extent along l, half extent along s] in metres.
obstacle = np.array([[40, 4, 0.3, 0.3],
                     [16, 1.5, 0.6, 0.3],  # [30, 2.5, 0.3, 0.3]
                     ])

# Grid size in cells. Plain int() is used because np.int was removed in NumPy 1.24.
numberOfL = int(totalWidth / detaWidth)
numberOfS = int(totalLength / detaLength)

class StrPosition:
    """Grid node (s index, l index) with optional lateral boundary derivatives.

    ``vpl`` and ``apl`` default to 0.0. nodeCost forwards them to the quintic fit
    as the boundary dl/ds and d2l/ds2.
    """

    def __init__(self, s, l, vpl=0.0, apl=0.0):
        self.s, self.l = s, l
        self.vpl, self.apl = vpl, apl


class StrReachablePosition:
    """A* search node: the step from ``parentNode`` to ``reachNode`` and its costs.

    ``f``, ``g`` and ``h`` all start at the placeholder value 999 until the search
    assigns real costs.
    """

    def __init__(self, parentNode, reachNode):
        self.parentNode, self.reachNode = parentNode, reachNode
        self.f = self.g = self.h = 999
        

def nodeCost(startPosition,endPosition,node2nodeCost):
    """Return the arc-length cost of the quintic edge between two grid nodes.

    Edges are cached in ``node2nodeCost`` under ``"<begin>id<end>"``. Each index
    is ``s + l * numberOfS``, and ``begin`` is the node with the smaller ``s``, so
    an edge and its reverse share one entry. showTrack looks edges up with the
    same key format.

    Args:
        startPosition, endPosition: nodes exposing ``s``, ``l``, ``vpl``, ``apl``.
        node2nodeCost: dict cache of ``huchang.trajectoryCost`` objects; updated in place.

    Returns:
        ``(node2nodeCost, cost)``.
    """
    # Always fit from the smaller-s node to the larger-s one so the quintic runs forward in s.
    if startPosition.s > endPosition.s:
        startPosition, endPosition = endPosition, startPosition

    begin = startPosition.s + startPosition.l * numberOfS
    end = endPosition.s + endPosition.l * numberOfS
    # NOTE(review): the key ignores vpl/apl, so edges between the same cells but
    # with different boundary slopes share one cached trajectory.
    ikey = str(begin) + 'id' + str(end)

    if ikey not in node2nodeCost:
        startState = huchang.VehicleState(startPosition.s, startPosition.l,
                                          startPosition.vpl, startPosition.apl)
        endState = huchang.VehicleState(endPosition.s, endPosition.l,
                                        endPosition.vpl, endPosition.apl)
        # Freshly built and owned only by the cache, so no deep copy is needed.
        node2nodeCost[ikey] = huchang.trajectoryCost(startState, endState, numberOfS, numberOfL,
                                                     detaLength, detaWidth)

    return node2nodeCost, node2nodeCost[ikey].cost
        
def findInQueue(closeMap,node):
    """Return True if the cell of ``node.reachNode`` has already been closed.

    ``closeMap`` maps a closed cell index (``l * numberOfS + s``) to its parent's
    cell index.
    """
    ikey = node.reachNode.l * numberOfS + node.reachNode.s
    # Test the key itself: the stored parent index can legitimately be 0
    # (parent at s=0, l=0), which a value-based check would misread as absent.
    return ikey in closeMap


def refreshMap(node,openList):
    """Store ``node`` in ``openList`` under its cell index (``l * numberOfS + s``).

    Any existing entry for that cell is overwritten. Returns the same dict.
    """
    reach = node.reachNode
    openList[reach.l * numberOfS + reach.s] = node
    return openList
      
def updateNode(node,openList,closeMap):
    """Add ``node`` to the open list, or improve the existing entry for its cell.

    If the cell is already open, its entry is replaced only when ``node`` has a
    strictly lower ``g``. If the cell is not open, ``node`` is added unless the
    cell is already closed.

    ``node`` is stored as a deep copy, so callers may keep reusing or mutating
    their own object. Only the inserted node is copied, which keeps each update
    O(1) instead of O(len(openList)).

    Returns:
        ``openList``, updated in place.
    """
    ikey = node.reachNode.l * numberOfS + node.reachNode.s
    if ikey in openList:
        if openList[ikey].g > node.g:
            openList[ikey] = copy.deepcopy(node)
    elif not findInQueue(closeMap, node):
        openList[ikey] = copy.deepcopy(node)
    return openList

def delMap(node,openList):
    """Remove the cell of ``node.reachNode`` from ``openList`` if present.

    Cell index 0 (s=0, l=0) is removed like any other. Skipping it would leave a
    start node at that cell in the open list, where it could be selected again
    on every iteration.

    Returns:
        ``openList``, updated in place.
    """
    ikey = node.reachNode.l * numberOfS + node.reachNode.s
    openList.pop(ikey, None)
    return openList


def environmentShowGraph(environment,startIndex,endIndex):
    """Plot the occupancy grid in index coordinates (x = s index, y = l index).

    Draws a blue line for every l row in ``[startIndex, endIndex)`` and for every
    s column, then marks occupied cells (value > 0.1) in those rows with red circles.
    """
    row, col = environment.shape[0], environment.shape[1]
    rows = range(startIndex, endIndex)

    # Each grid line is drawn once rather than once per cell.
    for rowi in rows:
        plt.plot(np.array([0, col - 1]), np.array([rowi, rowi]), 'b-')
    if len(rows) > 0:
        for coli in range(col):
            plt.plot(np.array([coli, coli]), np.array([0, row - 1]), 'b-')

    occupied_s = []
    occupied_l = []
    for rowi in rows:
        for coli in range(col):
            if environment[rowi][coli] > 0.1:
                occupied_s.append(coli)
                occupied_l.append(rowi)
    # 'ro' = red circle markers; a bare 'r' after the format string is not read as a colour.
    plt.plot(occupied_s, occupied_l, 'ro')


def environmentShow(environment,startIndex,endIndex):
    """Plot the occupancy grid in metres (x = s, y = l).

    Draws horizontal blue lines at ``l = row * detaWidth`` for rows in
    ``[startIndex, endIndex)`` spanning ``[0, totalLength]``. Draws vertical blue
    lines at ``s = col * detaLength`` spanning ``[0, totalWidth]``. Marks occupied
    cells (value > 0.1) in those rows with red circles.
    """
    col = environment.shape[1]
    rows = range(startIndex, endIndex)

    # Each grid line is drawn once rather than once per cell.
    for rowi in rows:
        y = rowi * detaWidth
        plt.plot(np.array([0, totalLength]), np.array([y, y]), 'b-')
    if len(rows) > 0:
        for coli in range(col):
            x = coli * detaLength
            plt.plot(np.array([x, x]), np.array([0, totalWidth]), 'b-')

    occupied_s = []
    occupied_l = []
    for rowi in rows:
        for coli in range(col):
            if environment[rowi][coli] > 0.1:
                occupied_s.append(coli * detaLength)
                occupied_l.append(rowi * detaWidth)
    # 'ro' = red circle markers; a bare 'r' after the format string is not read as a colour.
    plt.plot(occupied_s, occupied_l, 'ro')



def createObstacle(environment,obstacle,startIndex,endIndex):
    """Mark obstacle cells with 1.0 in ``environment`` (modified in place).

    Args:
        environment: 2-D array indexed ``[l_index][s_index]``.
        obstacle: array whose rows are ``[s, l, half extent along l, half extent
            along s]`` in metres.
        startIndex: number of l cells by which each obstacle is widened on both
            sides (the __main__ script passes the vehicle half width in cells).
        endIndex: unused; kept for interface compatibility.
    """
    erow, ecol = environment.shape[0], environment.shape[1]

    for ob in obstacle:
        # Use the parameter, not the __main__-only global startIndexOfL.
        start_l = math.floor((ob[1] - ob[2]) / detaWidth) - startIndex
        end_l = math.ceil((ob[1] + ob[2]) / detaWidth) + startIndex
        start_s = math.floor((ob[0] - ob[3]) / detaLength)
        end_s = math.ceil((ob[0] + ob[3]) / detaLength)

        for obsl_i in range(start_l, end_l + 1):
            for obss_i in range(start_s, end_s + 1):
                # Clip the inflated footprint to the grid.
                if 0 <= obsl_i < erow and 0 <= obss_i < ecol:
                    environment[obsl_i][obss_i] = 1.0
                    
            
def createStopPosition(x,startIndexaOfL,endIndexOfL):
    """Build candidate goal cells at s index ``x``, one per l index in ``[startIndexaOfL, endIndexOfL)``.

    Returns:
        Float array of shape ``(endIndexOfL - startIndexaOfL, 2)`` (empty if the
        range is empty); row i is ``[x, startIndexaOfL + i]``.
    """
    count = max(endIndexOfL - startIndexaOfL, 0)
    stopPosition = np.zeros((count, 2))
    stopPosition[:, 0] = x
    # l indices start at startIndexaOfL: reachPoints only visits l in
    # [startIndexOfL, endIndexOfL), so goals starting at l = 0 could be unreachable.
    stopPosition[:, 1] = np.arange(startIndexaOfL, startIndexaOfL + count)
    return stopPosition

def getHCost(startPosition,endPosition,node2nodeCost):
    """Heuristic cost h from ``startPosition`` to the goal ``endPosition``.

    h = (arc length of the direct quintic edge) + 0.1 * bias, where
    bias = dl * detaWidth - ds * detaLength * 0.1 and dl, ds are the
    start-minus-goal l and s index differences. The bias is signed.

    Returns:
        ``(node2nodeCost, h)``; the edge cache is updated by nodeCost.
    """
    node2nodeCost, arc = nodeCost(startPosition, endPosition, node2nodeCost)
    lateral = (startPosition.l - endPosition.l) * detaWidth
    longitudinal = (startPosition.s - endPosition.s) * detaLength * 0.1
    return node2nodeCost, arc + (lateral - longitudinal) * 0.1

def getGCost(node,endPosition,node2nodeCost):
    """Accumulated path cost g for stepping from ``node.reachNode`` to ``endPosition``.

    g = node.g + (arc length of the quintic edge) + 0.01 * bias, where
    bias = dl * detaWidth - ds * detaLength * 0.1 and dl, ds are this edge's
    from-minus-to l and s index differences (the same form getHCost uses).

    Returns:
        ``(node2nodeCost, g)``; the edge cache is updated by nodeCost.
    """
    origin = node.reachNode
    node2nodeCost, cost = nodeCost(origin, endPosition, node2nodeCost)
    # Measure the bias from this edge's own origin; the module-level
    # ``startPosition`` only exists when the file runs as __main__.
    cost0 = (origin.l - endPosition.l) * detaWidth - (origin.s - endPosition.s) * detaLength * 0.1
    return node2nodeCost, node.g + cost + 0.01 * cost0

def getFCost(node):
    """Return the A* priority f = g + h of ``node``."""
    g, h = node.g, node.h
    return g + h

def checkcollision(environment,startPosition,endPosition):
    """Return True if every cell in the box spanned by the two positions is free.

    The box covers all s indices and all l indices between the two positions,
    inclusive. A cell whose value exceeds 0.1 counts as occupied.
    """
    l_lo, l_hi = sorted((startPosition.l, endPosition.l))
    s_lo, s_hi = sorted((startPosition.s, endPosition.s))
    blocked = any(environment[li][si] > 0.1
                  for si in range(s_lo, s_hi + 1)
                  for li in range(l_lo, l_hi + 1))
    return not blocked
            


def reachPoints(openList,environment,lastNode,stopPosition,closeMap, startIndexOfL,endIndexOfL,node2nodeCost):
    """Expand ``lastNode`` into every collision-free cell up to MaxsSearch columns ahead.

    Candidates are cells with s in ``[parent.s + 1, min(parent.s + 1 + MaxsSearch,
    numberOfS))`` and l in ``[startIndexOfL, endIndexOfL)``. A candidate is kept
    when checkcollision finds the box between parent and candidate free. Each
    kept candidate gets its g, h and f costs and is offered to the open list via
    updateNode.

    Returns:
        ``(openList, node2nodeCost)``.
    """
    print("reachPoints-------------")
    parentNode = lastNode.reachNode
    nexts = parentNode.s + 1

    shape = environment.shape
    print("nexts,s of env", nexts, shape[1], numberOfS, numberOfL)
    maxsearchOfs = min(nexts + MaxsSearch, numberOfS)
    for s in range(nexts, maxsearchOfs):
        for l in range(startIndexOfL, endIndexOfL):
            # Fresh objects per candidate so nothing already placed in the open
            # list is mutated by later iterations.
            reachNode = StrPosition(s, l)
            if not checkcollision(environment, parentNode, reachNode):
                continue
            node = StrReachablePosition(parentNode, reachNode)
            node2nodeCost, node.g = getGCost(lastNode, reachNode, node2nodeCost)
            node2nodeCost, node.h = getHCost(reachNode, stopPosition, node2nodeCost)
            node.f = getFCost(node)
            openList = updateNode(node, openList, closeMap)
    return openList, node2nodeCost
    
def freshCloseList(node,openList,closeList,closeMap):
    """Move ``node`` from the open set to the closed set.

    Pushes ``node`` onto ``closeList`` and records ``closeMap[cell] = parent cell``
    (cell index = ``l * numberOfS + s``), which getTrack later follows. Then
    removes the cell from ``openList`` via delMap.

    Returns:
        ``(openList, closeList, closeMap)``.
    """
    print("freshCloseList-------------")
    child, parent = node.reachNode, node.parentNode
    closeList.put(node)
    closeMap[child.l * numberOfS + child.s] = parent.l * numberOfS + parent.s
    openList = delMap(node, openList)
    return openList, closeList, closeMap
     
 
def getMinFFromOpenList(openList):
    """Return the open node with the smallest f.

    Ties go to the first such node in dict iteration order.

    Returns:
        ``(node, True)`` when ``openList`` is non-empty. ``(placeholder, False)``
        when it is empty; the placeholder is a StrReachablePosition at (0, 0).
    """
    print("getMinFFromOpenList-------------")
    if len(openList) == 0:
        return StrReachablePosition(StrPosition(0, 0), StrPosition(0, 0)), False
    # min() has no upper sentinel, so nodes with very large f are still selected.
    noden = min(openList.values(), key=lambda n: n.f)
    print("min", noden.f, noden.reachNode.s, noden.reachNode.l)
    return noden, True
 


def AStart(environment,startPosition,stopPosition,startIndexOfL,endIndexOfL):
    """Run A* over the (s, l) lattice from ``startPosition`` to ``stopPosition``.

    Nodes are expanded forward in s by reachPoints. Edge costs are quintic arc
    lengths cached in one dict shared by all cost calls.

    Args:
        environment: occupancy grid indexed ``[l][s]``.
        startPosition, stopPosition: StrPosition grid nodes.
        startIndexOfL, endIndexOfL: l index range searched, ``[start, end)``.

    Returns:
        ``(closeList, closeMap, node2NodeCost, found)``:
        - closeList: the closed nodes in pop order.
        - closeMap: closed cell -> parent cell map.
        - node2NodeCost: the edge cache.
        - found: whether the goal was reached.
    """
    print("AStart-------------")
    node2NodeCost = dict()

    # Seed the open list with the start node.
    node = StrReachablePosition(startPosition, startPosition)
    node.g = 0
    # NOTE(review): h uses getGCost(start -> start) rather than getHCost(start ->
    # goal). The start-to-start edge has zero arc length, so the start node's f is
    # presumably ~0 and it is popped (closed) right after its first expansion.
    # Confirm this is intended before switching to getHCost.
    node2NodeCost, node.h = getGCost(node, startPosition, node2NodeCost)
    node.f = getFCost(node)
    openList = refreshMap(node, dict())
    print("dict len", len(openList))
    print("reachPosition", node.reachNode.s, node.reachNode.l)
    closeList = queue.Queue()
    closeMap = dict()

    # Parenthesised so the open-list check applies to both goal comparisons
    # (``and`` binds tighter than ``or``).
    while ((node.reachNode.s != stopPosition.s or node.reachNode.l != stopPosition.l)
           and len(openList) != 0):
        # Expand the current node's forward neighbours into the open list.
        openList, node2NodeCost = reachPoints(openList, environment, node, stopPosition, closeMap,
                                              startIndexOfL, endIndexOfL, node2NodeCost)
        node, flag = getMinFFromOpenList(openList)
        if not flag:
            print("openlist null")
            return closeList, closeMap, node2NodeCost, False
        openList, closeList, closeMap = freshCloseList(node, openList, closeList, closeMap)

        if node.reachNode.s == stopPosition.s and node.reachNode.l == stopPosition.l:
            return closeList, closeMap, node2NodeCost, True
    print("openlist null", len(openList))
    return closeList, closeMap, node2NodeCost, False
  
    
def getTrack(closeMap,startPosition,stopPosition):
    """Follow ``closeMap`` parent links from the goal back to the start.

    Args:
        closeMap: dict mapping a closed cell index (``l * numberOfS + s``) to its
            parent's cell index.
        startPosition, stopPosition: objects with grid indices ``s`` and ``l``.

    Returns:
        ``(slist, llist, track)``: the s and l indices from goal to start, and a
        deque of cell indices in the same order (goal first, start last).

    Raises:
        KeyError: if the goal or a cell on the chain was never closed.
        RuntimeError: if the parent links loop without reaching the start.
    """
    from collections import deque

    startkey = startPosition.l * numberOfS + startPosition.s
    endkey = stopPosition.l * numberOfS + stopPosition.s

    slist = [stopPosition.s]
    llist = [stopPosition.l]
    # int() replaces np.int, which was removed in NumPy 1.24.
    track = deque([int(endkey)])

    ikey = closeMap[endkey]
    while ikey != startkey:
        # A valid chain visits each closed cell at most once; longer means a cycle.
        if len(track) > len(closeMap):
            raise RuntimeError('closeMap parent chain does not reach the start cell')
        l_idx, s_idx = divmod(int(ikey), numberOfS)
        track.append(int(ikey))
        slist.append(s_idx)
        llist.append(l_idx)
        ikey = closeMap[ikey]

    slist.append(startPosition.s)
    llist.append(startPosition.l)
    track.append(int(startkey))
    return slist, llist, track

def showTrajackPointGraph(slist,llist):
    """Scatter-plot path points in grid-index coordinates (x = s index, y = l index)."""
    xs, ys = slist, llist
    plt.plot(xs, ys, 'o')
    

def showTrajackPoint(slist,llist):
    """Scatter-plot path points in metres, printing every scaled s value and then every scaled l value."""
    scaled_s = [s * detaLength for s in slist]
    for value in scaled_s:
        print('s', value)
    scaled_l = [l * detaWidth for l in llist]
    for value in scaled_l:
        print('l', value)
    plt.plot(scaled_s, scaled_l, 'o')

def showTrack(track,node2nodeCost):
    """Plot each cached quintic edge along ``track``.

    ``track`` is consumed from its right end, so a getTrack result (goal first,
    start last) is drawn from start to goal. Each consecutive pair ``(a, b)`` is
    looked up as ``node2nodeCost[f"{a}id{b}"]`` and its ``show()`` is called.
    """
    previous = track.pop()
    print('node2nodeCost', len(node2nodeCost))
    while track:
        current = track.pop()
        node2nodeCost[str(previous) + 'id' + str(current)].show()
        previous = current
    
    
def showCloseList(closeList):
    """Drain ``closeList`` and plot each closed node's cell in metres as a red star."""
    while closeList.qsize() > 0:
        reach = closeList.get().reachNode
        plt.plot(reach.s * detaLength, reach.l * detaWidth, 'r*')

def showCloseListGraph(closeList):
    """Drain ``closeList`` and plot each closed node's cell in index coordinates as a red star."""
    while closeList.qsize() > 0:
        reach = closeList.get().reachNode
        plt.plot(reach.s, reach.l, 'r*')
    

if __name__ == '__main__':

    # l range the vehicle centre may occupy: half a vehicle width (in cells)
    # is kept clear of each road edge. int() replaces np.int (removed in NumPy 1.24).
    startIndexOfL = int(vehicleHarfWidth / detaWidth)
    endIndexOfL = numberOfL - int(vehicleHarfWidth / detaWidth)

    environment = np.zeros((numberOfL, numberOfS))
    createObstacle(environment, obstacle, startIndexOfL, endIndexOfL)
    environmentShow(environment, startIndexOfL, endIndexOfL)

    # Single-goal A* ------------------------------------------------------
    startPosition = StrPosition(0, 12, 0.1, 0)
    stopPosition = StrPosition(18, 8)
    closeList, closeMap, node2NodeCost, Flag = AStart(environment, startPosition, stopPosition,
                                                      startIndexOfL, endIndexOfL)

    if Flag:
        slist, llist, track = getTrack(closeMap, startPosition, stopPosition)

        for t in track:
            print('result track', t)

        showTrajackPoint(slist, llist)
        showTrack(track, node2NodeCost)
    else:
        print("no route-----------------")
        showCloseList(closeList)

    # Multi-goal A* (disabled) ----------------------------------------------
    # NOTE(review): createStopPosition returns a float array, while checkcollision
    # passes s/l to range(); cast the goal coordinates to int before re-enabling.
    '''
    stopPosition = createStopPosition(numberOfS-1,startIndexOfL,endIndexOfL)  
    startPosition = StrPosition(0,9)

    shapeOfEnd = stopPosition.shape
    for i  in  range(0,shapeOfEnd[0]):
        stop = StrPosition(stopPosition[i][0],stopPosition[i][1])
        closeList,closeMap,node2NodeCost,Flag = AStart(environment,startPosition,stop,startIndexOfL,endIndexOfL)
    
        if Flag == True:
            slist,llist,track = getTrack(closeMap,startPosition,stop)   
            for t in track:
                print('result track',t)
        
            showTrajackPoint(slist,llist)
            showTrack(track,node2NodeCost)
        else:
            print("no route-----------------")
            showCloseList(closeList)
        #showCloseList(closeList)   
    '''

    




          
    
        
    

 

posted on 2019-12-28 17:11  卡贝天师  阅读(437)  评论(0)  编辑  收藏  举报