A*算法 -- 八数码问题和传教士过河问题的代码实现

  前段时间人工智能课上讲到了A*算法,于是我去了解了一下,并试着用它解决经典的八数码问题。一开始写的时候花了挺长时间,后来我把算法框架抽离出来,写成了一个通用的算法模板,这样以后再需要用A*算法时,就可以基于这个模板快速开发了(刷OJ题当然不适合,但平时写一些小游戏之类的东西是可以用的)。

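  A*算法的原理就不过多介绍了,网上能找到一大堆。它的核心是估价(启发)函数的定义:在本文代码中它叫 astar_g(),相当于教科书里的 h(n);而 astar_f() 是已走的步数(教科书中的 g(n)),astar_h() 是两者之和(教科书中的 f(n))。启发函数会直接影响搜索速度,而且只有当它从不高估剩余代价(即满足可采纳性)时,A*才能保证找到最短路径。我在代码里利用 C++/Java 的多态性来编写与业务无关的算法模板:用一个抽象类表示搜索树中的状态,A*算法主类直接操纵这个抽象类;然后编写业务相关的类去继承这个抽象类,并实现其中所有的抽象方法(C++里是纯虚函数)。之后调用A*算法主类的 run 函数,就能得到一条可行的搜索路径(启发函数可采纳时即为最短路径)。下面具体看代码:(文末附所有代码的 github 地址)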

先看 C++ 部分,毕竟一开始就是用 C++ 写的。

首先是表示状态的抽象基类CState,头文件 state.h:

#ifndef  STATE_H
#define  STATE_H

#include <cstddef>
#include <vector>
using std::vector;

/**
 * Abstract base class for a node of the A* search tree.
 *
 * A concrete problem derives from CState and implements operator<,
 * getNextState() and astar_g(); CAstar works only through this interface.
 * (Include guard renamed from the reserved identifier __state_h.)
 */
class CState
{
public:
    CState();
    // Strict weak ordering used to detect duplicate states: two states are
    // treated as equal when neither is less than the other.
    virtual bool operator < (const CState &) const=0;
    // Hook for checking that a start state and an end state are compatible;
    // the base implementation accepts everything.
    virtual void checkSomeFields(const CState &) const;
    // Returns newly allocated successor states; the caller takes ownership.
    virtual vector<CState*> getNextState() const=0;
    // Calls getNextState() and fills in iSteps and pparent of every successor.
    vector<CState*> __getNextState() const;
    // Cost accumulated so far (the base implementation returns iSteps).
    virtual long astar_f() const;
    // Heuristic estimate of the remaining cost; the smaller the value, the
    // higher the priority (the same holds for f() and h()).
    virtual long astar_g() const=0;
    // Priority key of the open list (the base implementation returns f() + g()).
    virtual long astar_h() const;
    virtual ~CState();

    int iSteps;             // number of moves from the start state
    const CState *pparent;  // predecessor on the search path; must point to an object that really exists, never to a local variable
};

#endif

源文件 state.cpp:

#include "state.h"
#include <algorithm>
using std::for_each;

CState::CState(): iSteps(0), pparent(NULL) {}

// Default start/end compatibility check: nothing to verify, never throws.
void CState::checkSomeFields(const CState & /*rhs*/) const
{
}

// Expands this state and stamps every successor with its depth and parent
// pointer, so concrete classes only have to produce the raw successors.
vector<CState*> CState::__getNextState() const
{
    vector<CState*> children = getNextState();
    for(CState *child : children) {
        child->iSteps = this->iSteps + 1;
        child->pparent = this;
    }
    return children;
}

long CState::astar_f() const
{
    return iSteps;
}

// Priority key of the open list: cost so far plus the heuristic estimate.
long CState::astar_h() const
{
    const long costSoFar = astar_f();
    const long estimate = astar_g();
    return costSoFar + estimate;
}

// Nothing to release; the destructor exists only to be virtual.
CState::~CState() = default;

  子类只需实现小于运算符,getNextState(),astar_g() 这三个纯虚函数就可以了,另外几个虚函数可以不重写,直接用父类的即可。

  然后是A*算法主类 CAstar,头文件 astar.h:

#ifndef  ASTAR_H
#define  ASTAR_H

#include "state.h"
#include <set>
using std::set;

/**
 * Generic A* driver that operates on CState pointers.
 *
 * After run(), vecSolve holds the solution path from the end state back to the
 * root of the search tree (vecSolve.back() is the start state). CAstar deletes
 * the heap-allocated states it keeps in its destructor, so it must not be
 * copied: a copy would delete the same states a second time.
 */
class CAstar
{
public:
    CAstar(const CState &_start, const CState &_end);
    CAstar(const CAstar &) = delete;
    CAstar& operator=(const CAstar &) = delete;
    static set<const CState*> getStateByStartAndSteps(const CState &start, int steps);
    void run();
    ~CAstar();

    const CState &m_rStart, &m_rEnd;   // caller-owned; must outlive this object
    bool bCanSolve;                    // true when run() reached m_rEnd
    int iSteps;                        // length of the solution path
    vector<const CState*> vecSolve;    // solution path, end state first
    long lRunTime;                     // CTimeVal::costTime() measured for run()
    int iTotalStates;                  // distinct states known when run() finished
private:
    set<const CState*> pointerWaitToDelete;   // superseded states, freed in the destructor
};

#endif

源文件 astar.cpp:

#include "astar.h"
#include "timeval.h"
#include "exception.h"
#include <set>
#include <queue>
#include <algorithm>
#include <cstdlib>
#include <functional>
using std::set;
using std::queue;
using std::priority_queue;
using std::swap;
using std::max;
using std::sort;
using std::function;
#define  For(i,s,t)  for(auto i = (s); i != (t); ++i)

// Keeps references to the caller's start/end states and lets the concrete
// state type check that they are compatible (an override such as
// CChess::checkSomeFields may throw CException).
// Initializers are listed in member-declaration order.
CAstar::CAstar(const CState &_start, const CState &_end):
    m_rStart(_start),
    m_rEnd(_end),
    bCanSolve(false),
    iSteps(0),
    vecSolve(),
    lRunTime(0),
    iTotalStates(0),
    pointerWaitToDelete()
{
    m_rStart.checkSomeFields(m_rEnd);
}

// Comparator for containers of pointers that orders them by the pointees'
// operator< instead of by address (used to deduplicate equal states).
template <typename T>
struct CPointerComp
{
    bool operator () (const T &lhs, const T &rhs) const
    {
        const auto &left = *lhs;
        const auto &right = *rhs;
        return left < right;
    }
};

/**
 * Returns every state whose breadth-first distance from `start` is exactly
 * `steps` (duplicates are detected with CState::operator<).
 *
 * Ownership: every returned pointer except `start` itself (which is returned
 * only when steps == 0) is heap-allocated and must be deleted by the caller.
 * All intermediate states are freed before returning, so the pparent field of
 * the returned heap states is reset to NULL instead of being left dangling.
 */
set<const CState*> CAstar::getStateByStartAndSteps(const CState &start, int steps)
{
    set<const CState*> retSet;
    set<const CState*, CPointerComp<const CState*> > inSet;
    inSet.insert(&start);
    queue<const CState*> queState;
    queState.push(&start);
    while(!queState.empty()) {
        const CState* const pCurState = queState.front();
        queState.pop();
        if(pCurState->iSteps > steps) {
            continue;
        }
        if(pCurState->iSteps == steps) {
            retSet.insert(pCurState);
            continue;
        }
        for(CState *pNext : pCurState->__getNextState()) {
            if(inSet.find(pNext) == inSet.end()) {
                queState.push(pNext);
                inSet.insert(pNext);
            } else {
                delete pNext;   // already reached by a path that is not longer
            }
        }
    }
    // Free every visited state that is neither the caller's start nor a result.
    for(const CState *pState : inSet) {
        if(pState != &start && retSet.find(pState) == retSet.end()) {
            delete pState;
        }
    }
    // The results' parents were just freed: do not hand out dangling pointers.
    for(const CState *pState : retSet) {
        if(pState != &start) {
            // Safe: these objects were created non-const by getNextState().
            const_cast<CState*>(pState)->pparent = NULL;
        }
    }
    return retSet;
}

// Ordering for std::priority_queue that turns it into a min-heap on
// astar_h(): the state with the smallest key is on top.
struct priority_state
{
    bool operator () (const CState* const lhs, const CState* const rhs) const
    {
        const long leftKey = lhs->astar_h();
        const long rightKey = rhs->astar_h();
        return rightKey < leftKey;
    }
};

void CAstar::run()
{
    CTimeVal _time;

    set<const CState*, CPointerComp<const CState*>> setState;
    setState.insert(&m_rStart);
    priority_queue<const CState*, vector<const CState*>, priority_state> queState;
    queState.push(&m_rStart);

    while(!queState.empty()) {
        // auto pHeadState = *(setState.find(queState.top()));
        auto pHeadState = queState.top();
        queState.pop();
        if(!(*pHeadState < m_rEnd) && !(m_rEnd < *pHeadState)) {
            bCanSolve = true;
            iSteps = pHeadState->iSteps;
            vecSolve.push_back(pHeadState);
            const CState *lastState = pHeadState->pparent;
            while(lastState != NULL) {
                vecSolve.push_back(lastState);
                lastState = lastState->pparent;
            }
            break;
        }
        vector<CState*> nextState = pHeadState->__getNextState();
        int len = nextState.size();

        for(int i = 0; i < len; ++i) {
            auto state_it = setState.find(nextState[i]);
            if(state_it == setState.end()) {
                queState.push(nextState[i]);
                setState.insert(nextState[i]);
            } else {
                if((*state_it)->astar_f() > nextState[i]->astar_f()) {
                    pointerWaitToDelete.insert(*state_it);        // 这一句要放在setState.erase前面,防止迭代器失效
                    setState.erase(state_it);
                    setState.insert(nextState[i]);
                    queState.push(nextState[i]);
                } else {
                    delete nextState[i];
                }
            }
        }
        if(setState.size() > 6000 * 10000) {
            break ;
        }
    }
    iTotalStates = setState.size();
    lRunTime = _time.costTime();
    setState.erase(&m_rStart);
    For(vec_it, vecSolve.begin(), vecSolve.end()) {
        setState.erase(*vec_it);
    }
    For(s_it, setState.begin(), setState.end()) {
        delete *s_it;
    }
}

/**
 * Frees the states this object still owns: the solution path (except the
 * caller-owned start and end objects) and the superseded states parked by
 * run(). A superseded state can also lie on the solution path, so the two
 * groups are merged into one set first; otherwise it would be deleted twice.
 */
CAstar::~CAstar()
{
    set<const CState*> owned(pointerWaitToDelete.begin(), pointerWaitToDelete.end());
    owned.insert(vecSolve.begin(), vecSolve.end());
    owned.erase(&m_rStart);
    owned.erase(&m_rEnd);
    for(const CState *pState : owned) {
        delete pState;
    }
}

   主搜索函数用优先队列实现了最佳优先搜索(可以理解为把广度优先搜索中的普通队列换成按 astar_h() 排序的优先队列),以此完成A*算法。因为是用多态实现的,需要用到指针,所以有些细节写得不是很好看。不过经过运行测试,没有发现明显的bug,CPU和内存的使用也都在正常范围内。

  以上两个类就是A*算法的主体框架了,但里面用到了自定义的异常类 CException 和计时类 CTimeVal 等一些工具类,具体代码可以在后面的 github 地址里看到。

  然后是业务相关的类,这里首先是八数码问题的类 CChess,头文件 chess.h:

#ifndef  CCHESS_H
#define  CCHESS_H

#include "state.h"
#include <iostream>
#include <string>
#include <vector>
using std::string;
using std::vector;
using std::ostream;


/**
 * Sliding-puzzle board (the eight-digit puzzle generalised to iRow x iCol).
 *
 * The board is stored row by row in a string; the character '0' is the blank.
 * strStandard is the goal layout used by the heuristic astar_g().
 * (Include guard renamed from the reserved identifier __CCHESS_H.)
 */
class CChess: public CState
{
    friend ostream& operator << (ostream &, const CChess &);
    static int iLimitNum;           // upper bound on iRow * iCol
public:
    // Throws CException when the dimensions or characters are invalid. An
    // empty `standard` means "the characters of `state` in sorted order".
    CChess(const string &state, int row, int col, const string &standard="");

    virtual bool operator < (const CState &) const override;
    virtual void checkSomeFields(const CState &) const override;

    const string& getStrState() const;
    void setStrStandard(const string &);
    virtual vector<CState*> getNextState() const override;
    // virtual long astar_f() const;
    virtual long astar_g() const override;
    // virtual long astar_h() const;

private:
    void check_row_col() const;
    void check_value() const;
    void check_standard() const;
    inline int countNotMatch() const;
    inline int countLocalNotMatch(int, int) const;

private:
    string strState;        // current board, row-major
    int iRow, iCol;
    int iZeroIdx;           // index of '0' in strState
    string strStandard;     // goal board, row-major
    int iNotMatch;          // cached heuristic value returned by astar_g()

public:
    int iMoveFromLast;      // DIRECT value of the move that produced this board
    static const string directs[5];   // printable names, indexed by DIRECT
    enum DIRECT
    {
        UP, DOWN, LEFT, RIGHT, UNKOWN
    };
    void output(ostream &out, const string &colSpace=" ", const string &rowSpace="\n") const;
};

#endif
chess.h

源文件 chess.cpp:

#include "chess.h"
#include "exception.h"
#include <cstring>
#include <algorithm>
using std::sort;
using std::swap;
#define  For(i,s,t)  for(auto i = (s); i != (t); ++i)

// Largest board size (rows * columns) accepted by check_row_col().
int CChess::iLimitNum{20};
// Printable move names, indexed by CChess::DIRECT.
const string CChess::directs[5]{"up", "down", "left", "right", "unkown"};

void CChess::check_row_col() const
{
    if(iRow <= 0 || iCol <= 0) {
        throw CException(1001, "行或列的值不能小于等于0!");
    }
    if(iRow * iCol > iLimitNum) {
        char msg[100];
        sprintf(msg, "行列数的乘积不能超过%d!", iLimitNum);
        throw CException(1002, msg);
    }
    if(iRow * iCol != strState.size()) {
        throw CException(1003, "行列数的乘积应该和字符串的长度相等!");
    }
}

void CChess::check_value() const
{
    if(iZeroIdx == string::npos) {
        throw CException(1004, "字符串值不合法,应该含有'0'!");
    }
    bool ch[300];
    memset(ch, 0, sizeof(ch));
    int len = strState.size();
    for(int i = 0; i < len; ++i) {
        if(ch[strState[i]] == true) {
            throw CException(1005, "字符串中不能含有相同的字符!");
        }
        ch[strState[i]] = true;
    }
}

void CChess::check_standard() const
{
    int len = strState.size();
    int len2 = strStandard.size();
    if(len != len2) {
        throw CException(1006, "目标状态的字符长度与当前状态的字符长度不等!");
    }
    bool origin[300];
    memset(origin, false, sizeof origin);
    For(i, 0, len) {
        origin[strState[i]] = true;
    }
    bool standard[300];
    memset(standard, false, sizeof standard);
    For(i, 0, len) {
        standard[strStandard[i]] = true;
    }
    For(i, 0, 300) {
        if(origin[i] != standard[i]) {
            throw CException(1007, "目标状态的字符内容与当前状态的字符内容不等!");
        }
    }
}

/**
 * Builds a board from its row-major string and validates it.
 * When `standard` is empty, the goal defaults to the board's own characters
 * sorted in ascending order. Throws CException on invalid input.
 */
CChess::CChess(const string &state, int row, int col, const string &standard):
    CState(),
    strState(state),
    iRow(row),
    iCol(col),
    strStandard(standard),
    iMoveFromLast(UNKOWN)
{
    check_row_col();
    iZeroIdx = strState.find('0');
    check_value();
    if(strStandard.empty()) {
        string sortedGoal(strState);
        sort(sortedGoal.begin(), sortedGoal.end());
        strStandard = sortedGoal;
    }
    check_standard();
    iNotMatch = countNotMatch();
}

void CChess::checkSomeFields(const CState &rhs) const
{
    if(iRow != ((CChess*)&rhs)->iRow) {
        throw CException(2001, "开始字符串和结束字符串的行不相同!");
    }
    if(iCol != ((CChess*)&rhs)->iCol) {
        throw CException(2002, "开始字符串和结束字符串的列不相同!");
    }
    auto tmp_this = strState;
    auto tmp_rhs = ((CChess*)&rhs)->strState;
    sort(tmp_this.begin(), tmp_this.end());
    sort(tmp_rhs.begin(), tmp_rhs.end());
    if(tmp_this != tmp_rhs) {
        throw CException(2003, "开始字符串和结束字符串含有的字符有差别!");
    }
}

// Orders boards by their strings (strcmp), then by row count, then by column
// count; used to detect duplicate states.
bool CChess::operator < (const CState &rhs) const
{
    const CChess &other = static_cast<const CChess&>(rhs);
    const int cmp = strcmp(strState.c_str(), other.strState.c_str());
    if(cmp != 0) {
        return cmp < 0;
    }
    if(iRow != other.iRow) {
        return iRow < other.iRow;
    }
    return iCol < other.iCol;
}

// Read-only access to the row-major board string.
const string& CChess::getStrState() const
{
    return this->strState;
}

void CChess::setStrStandard(const string &standard)
{
    strStandard = standard;
    check_standard();
    iNotMatch = countNotMatch();
}

int CChess::countNotMatch() const
{
    int notMatch = 0;
    For(i, 0, iRow) {
        For(j, 0, iCol) {
            if(strState[i * iCol + j] != strStandard[i * iCol + j]) {
                ++notMatch;
            }
        }
    }
    return notMatch;
}

int CChess::countLocalNotMatch(int one, int two) const
{
    int oldNotMatch = 0;
    if(strState[two] != strStandard[one]) {
        ++oldNotMatch;
    }
    if(strState[one] != strStandard[two]) {
        ++oldNotMatch;
    }
    int nowNotMatch = 0;
    if(strState[one] != strStandard[one]) {
        ++nowNotMatch;
    }
    if(strState[two] != strStandard[two]) {
        ++nowNotMatch;
    }
    return this->iNotMatch - oldNotMatch + nowNotMatch;
}

/**
 * Generates the boards reachable by sliding one neighbouring tile into the
 * blank. Each successor is a heap-allocated copy of this board; its iNotMatch
 * is updated incrementally and iMoveFromLast records the direction the tile
 * moved. Successors are produced in the order DOWN, UP, RIGHT, LEFT.
 */
vector<CState*> CChess::getNextState() const
{
    vector<CState*> children;
    auto slideInto = [this, &children](int tileIdx, int direction) {
        CChess *child = new CChess(*this);
        swap(child->strState[tileIdx], child->strState[iZeroIdx]);
        child->iNotMatch = child->countLocalNotMatch(tileIdx, iZeroIdx);
        child->iZeroIdx = tileIdx;
        child->iMoveFromLast = direction;
        children.push_back(child);
    };

    // The tile above the blank moves down.
    if(iZeroIdx >= iCol) {
        slideInto(iZeroIdx - iCol, CChess::DOWN);
    }
    // The tile below the blank moves up.
    if(iZeroIdx < strState.size() - iCol) {
        slideInto(iZeroIdx + iCol, CChess::UP);
    }
    // The tile to the left of the blank moves right.
    if(iZeroIdx % iCol != 0) {
        slideInto(iZeroIdx - 1, CChess::RIGHT);
    }
    // The tile to the right of the blank moves left.
    if((iZeroIdx + 1) % iCol != 0) {
        slideInto(iZeroIdx + 1, CChess::LEFT);
    }
    return children;
}

long CChess::astar_g() const
{
    return iNotMatch;
}

// Writes the board as a grid: colSpace between cells of a row and rowSpace
// between rows (no trailing separators).
void CChess::output(ostream &out, const string &colSpace, const string &rowSpace) const
{
    for(int r = 0; r < iRow; ++r) {
        if(r != 0) {
            out << rowSpace;
        }
        for(int c = 0; c < iCol; ++c) {
            if(c != 0) {
                out << colSpace;
            }
            out << strState[r * iCol + c];
        }
    }
}

// Streams the board with the default separators, followed by a newline.
std::ostream& operator << (std::ostream &out, const CChess &oChess)
{
    oChess.output(out);
    return out << "\n";
}
chess.cpp

  八数码问题的部分当时花了挺长时间,做了不少优化(例如每次移动后,countLocalNotMatch 只增量更新被交换的两个格子的不匹配数)。最后是 main 函数,提供简单的交互功能:

#include "chess.h"
#include "exception.h"
#include "astar.h"
#include "timeval.h"
#include <unistd.h>
#include <string.h>
#include <iostream>
using namespace std;

int main(int argc, char const *argv[])
{
    string str;
    int r, c;
    while(true) {
        try {
            cout << "please input the start state(string) and row, col, separate with a space:\n";
            if(bool(cin >> str >> r >> c) == false) {
                break;
            }
            CChess start(str, r, c);
            cout << "please input the end state(string) and row, col, separate with a space:\n";
            if(bool(cin >> str >> r >> c) == false) {
                break;
            }
            CChess end(str, r, c);
            start.setStrStandard(str);

            cout << "Your game is:\n" << start << " --->\n" << end << "\n";

            CAstar game(start, end);
            game.run();

            if(game.bCanSolve == true) {
                cout << "your game can be solve:\n";
                cout << "the total steps is: " << game.iSteps << "\n";
                cout << "and the path is:\n";
                int len = game.vecSolve.size();
                cout << *((CChess*)(game.vecSolve[len - 1])) << "\n";
                for(int i = len - 2; i >= 0; --i) {
                    cout << "    |\n";
                    cout << "    |   " << CChess::directs[((CChess*)(game.vecSolve[i]))->iMoveFromLast] << "\n";
                    cout << "   \\|/\n\n";
                    cout << *((CChess*)(game.vecSolve[i])) << "\n";
                }
            } else {
                cout << "sorry, your game can't be solve, please input the other state.\n\n";
            }
            cout << " and the max states is: " << game.iTotalStates << "\n";
            cout << " and the runtime is: " << game.lRunTime << "\n";
            
        } catch (const CException &ex) {
            cerr << ex.code << ": " << ex.msg << "\n";
        } catch (...) {
            break;
        }
    }
    return 0;
}
main.cpp

  还写了个用于生成测试用例的程序:

#include "astar.h"
#include "chess.h"
#include "exception.h"
#include <string>
#include <iostream>
#include <sstream>
#include <cctype>
#include <algorithm>
using namespace std;

// Prints the command-line synopsis of this test-case generator to stdout.
void usage(const string &exe_name)
{
    cout << "Usage: " << exe_name
         << " start(string) row(positive int) col(positive int) steps(positive int)."
         << "\n";
}

// Parses the first whitespace-delimited token of `str` as a T using stream
// extraction. On a failed arithmetic extraction the stream stores 0.
template <typename T>
T strTo(const string &str)
{
    istringstream parser(str);
    T value;
    parser >> value;
    return value;
}

int main(int argc, char const *argv[])
{
    if(argc < 5) {
        usage(argv[0]);
        exit(1);
    }
    int row = strTo<int>(argv[2]);
    int col = strTo<int>(argv[3]);
    if(!row || !col) {
        usage(argv[0]);
        exit(2);
    }
    int steps = strTo<int>(argv[4]);
    string strStandard(argv[1]);
    bool outputOne = true;
    if(argc >= 6) {
        outputOne = false;
    }
    try {
        CChess start(strStandard, row, col);
        auto setChess = CAstar::getStateByStartAndSteps(start, steps);
        if(setChess.size() == 0) {
            throw CException(3001, "走不了这么多步!");
        }
        if(outputOne == true) {
            auto first = *setChess.begin();
            ((CChess*)first)->output(cout);
            cout << "\t\t" << ((CChess*)first)->getStrState() << "\n";
        }
        else {
            cout << "setChess.size() = " << setChess.size() << "\n";
            for_each(setChess.begin(), setChess.end(), [](const CState *elem) {
                ((CChess*)elem)->output(cout);
                cout << "\t\t" << ((CChess*)elem)->getStrState() << "\n";
            });
        }
        for_each(setChess.begin(), setChess.end(), [](const CState *elem){
            delete elem;
        });
    } catch(const CException &ex) {
        cerr << ex.code << ": " << ex.msg << "\n";
        exit(3);
    } catch(...) {
        cerr << "unkown error.\n";
        exit(4);
    }
    return 0;
}
rand_init.cpp

  makefile文件:

CC = g++
COMOPT = -std=c++11

INCLUDEDIR = -I./tools

LIBDIR = -L./tools
LIBS = -ltools
LINK = $(LIBDIR) $(LIBS)

# OBJS = $(patsubst %.cpp, %.o, $(wildcard *.cpp))
OBJS += chess.o astar.o state.o

OUTPUT += game rand_init

.PHONY: all clean

all: $(OUTPUT)

# Sub-makes are invoked through $(MAKE) so that options such as -j/-n and the
# jobserver are forwarded. Recipe lines must start with a tab character.
game: $(OBJS) main.o
	$(MAKE) -C tools
	$(CC) -o $@ $^ $(LINK)

rand_init: $(OBJS) rand_init.o
	$(MAKE) -C tools
	$(CC) -o $@ $^ $(LINK)

%.o: %.cpp
	$(CC) -o $@ -c $< $(COMOPT) $(INCLUDEDIR)

clean:
	$(MAKE) clean -C tools
	rm -f *.o
	rm -f $(OUTPUT)
makefile

  完整的八数码问题程序:astar_EightDigital

 

  然后是传教士过河问题,CState 类和 CAstar 类和上面一样,具体的业务实现类 CPersonState 如下:

#ifndef  PERSON_H
#define  PERSON_H

#include "state.h"
#include <iostream>
using std::ostream;

/**
 * State of the missionaries-and-savages river-crossing puzzle.
 *
 * The stored counts are those on the bank where the boat currently is. The
 * problem size is shared by all states through the static members, which must
 * be set before init() is called (init() throws otherwise).
 * (Include guard renamed from the reserved identifier __PERSON_H.)
 */
class CPersonState: public CState
{
    friend ostream& operator << (ostream&, const CPersonState&);
public:
    CPersonState();
    virtual bool operator < (const CState &) const override;
    virtual vector<CState*> getNextState() const override;
    virtual long astar_g() const override;
    void init(int, int, int);
public:
    static int iTotalMissionary;    // the total number of missionaries
    static int iTotalSavage;        // the total number of savages
    static int iBoatCapacity;       // the capacity of the boat
private:
    // Zero-initialised so a default-constructed state never holds garbage;
    // init() assigns the real values (0 is not a valid boat position).
    int iMissionary = 0;      // the number of missionaries in the shore where boat anchors
    int iSavage = 0;          // the number of savages in the shore where boat anchors
    int iBoatPosition = 0;    // the position of boat, this shore or opposite shore
public:
    enum POSITION
    {
        THIS_SHORE = 1, OPPOSITE_SHORE
    };
    int iMoveMissionary = 0;  // missionaries carried by the move that produced this state
    int iMoveSavage = 0;      // savages carried by the move that produced this state
};

#endif
person.h
#include "person.h"
#include "exception.h"
#include <algorithm>
using std::min;
using std::max;
#define  For(i,s,t)  for(auto i = (s); i != (t); ++i)

// -1 means "not configured yet"; init() refuses to build a state until the
// caller has assigned real values.
int CPersonState::iTotalMissionary{-1};
int CPersonState::iTotalSavage{-1};
int CPersonState::iBoatCapacity{-1};

// Creates an unconfigured state; call init() before using it.
CPersonState::CPersonState(): CState() {}

/**
 * Sets this state to (_m missionaries, _s savages) on the bank where the boat
 * is, with the boat at position _b, then validates it against the configured
 * problem size. The fields are assigned before validation.
 *
 * Throws CException 101-103 when a static size is unset, 108/109 for negative
 * sizes or counts, 104/105 when a count exceeds its total, 106 when the
 * missionaries on this bank are outnumbered, and 107 for an invalid boat
 * position.
 */
void CPersonState::init(int _m, int _s, int _b)
{
    iMissionary = _m;
    iSavage = _s;
    iBoatPosition = _b;
    iMoveMissionary = iMoveSavage = 0;

    if(iTotalMissionary == -1) {
        throw CException(101, "the total number of missionaries has not been initialized.");
    }
    if(iTotalSavage == -1) {
        throw CException(102, "the total number of savages has not been initialized.");
    }
    if(iBoatCapacity == -1) {
        throw CException(103, "the capacity of the boat has not been initialized.");
    }
    // Negative sizes or counts would make getNextState()'s loop bounds meaningless.
    if(iTotalMissionary < 0 || iTotalSavage < 0 || iBoatCapacity < 0) {
        throw CException(108, "the total numbers of missionaries and savages and the capacity of the boat can't be negative.");
    }
    if(iMissionary < 0 || iSavage < 0) {
        throw CException(109, "the number of missionaries or savages on this shore can't be negative.");
    }
    if(iMissionary > iTotalMissionary) {
        throw CException(104, "the number of missionaries on this shore exceeded the total number.");
    }
    if(iSavage > iTotalSavage) {
        throw CException(105, "the number of savages on this shore exceeded the total number.");
    }
    if(iMissionary && iMissionary < iSavage) {
        throw CException(106, "the number of missionaries can\'t be less then the number of savages.");
    }
    if(iBoatPosition != CPersonState::THIS_SHORE && iBoatPosition != CPersonState::OPPOSITE_SHORE) {
        // Adjacent literals are concatenated; a backslash-newline inside the
        // literal would also embed the next line's indentation in the message.
        throw CException(107, "the value of iBoat is invalid, which must be CPersonState::THIS_SHORE "
                              "or CPersonState::OPPOSITE_SHORE, you can use 1 or 2 certainly.");
    }
}

// Lexicographic order on (missionaries, savages, boat position); used to
// detect duplicate states.
bool CPersonState::operator < (const CState &rhs) const
{
    const CPersonState &other = static_cast<const CPersonState&>(rhs);
    if(iMissionary != other.iMissionary) {
        return iMissionary < other.iMissionary;
    }
    if(iSavage != other.iSavage) {
        return iSavage < other.iSavage;
    }
    return iBoatPosition < other.iBoatPosition;
}

using std::cin;
using std::cout;

/**
 * Generates every legal boat trip from the bank where the boat is.
 *
 * x missionaries and y savages cross (at least one person, at most
 * iBoatCapacity). A trip is legal when, afterwards, neither bank has
 * missionaries outnumbered by savages, and when any missionaries aboard are
 * not outnumbered on the boat. Each successor describes the opposite bank,
 * where the boat now is, and is heap-allocated (the caller takes ownership).
 */
vector<CState*> CPersonState::getNextState() const
{
    vector<CState*> nextState;
    int oppo_m = iTotalMissionary - iMissionary;
    int oppo_s = iTotalSavage - iSavage;
    int mk = min(iMissionary, iBoatCapacity);
    int sk = min(iSavage, iBoatCapacity);
    // Use <= rather than the For macro's != test: with a negative bound (for
    // example a negative count typed by the user) != never terminates.
    for(int x = 0; x <= mk; ++x) {
        for(int y = 0; y <= sk; ++y) {
            if(!x && !y)    continue;    // the boat cannot cross empty
            // Departure bank unsafe; taking more savages (a larger y) may fix it.
            if(iMissionary - x != 0 && iMissionary - x < iSavage - y)    continue;
            // Boat over capacity or unsafe; a larger y only makes it worse.
            if(x + y > iBoatCapacity || (x && y > x) )    break;
            // Arrival bank unsafe; a larger y only makes it worse.
            if(oppo_m + x != 0 && oppo_m + x < oppo_s + y)    break;
            CPersonState *_next = new CPersonState();
            _next->init(oppo_m + x, oppo_s + y, 3 - iBoatPosition);
            _next->iMoveMissionary = x;
            _next->iMoveSavage = y;
            nextState.push_back(_next);
        }
    }
    return nextState;
}

/**
 * Heuristic: the number of people still on the starting bank (THIS_SHORE).
 * NOTE(review): this can exceed the number of crossings still needed (two
 * people left with a capacity-2 boat on the starting bank need one trip but
 * score 2), so it is not admissible, and A* may not return the shortest plan.
 */
long CPersonState::astar_g() const
{
    const bool boatOnStartBank = (iBoatPosition == CPersonState::THIS_SHORE);
    const int remain_num = boatOnStartBank
        ? iMissionary + iSavage
        : (iTotalMissionary - iMissionary) + (iTotalSavage - iSavage);
    return remain_num;
}

// Prints the state as "(missionaries, savages, boat position)".
ostream& operator << (ostream &out, const CPersonState &state)
{
    return out << "(" << state.iMissionary
               << ", " << state.iSavage
               << ", " << state.iBoatPosition << ")";
}
person.cpp

  main 函数:

#include "person.h"
#include "astar.h"
#include "exception.h"
#include <iostream>
using namespace std;

int main(int argc, char const *argv[])
{
    int m,s,k;
    while(true) {
        cout << "please input the number of missionaries, savages and the capacity of the boat, separate with a space:\n";
        if(bool(cin >> m >> s >> k) == false) {
            break;
        }
        CPersonState::iTotalMissionary = m;
        CPersonState::iTotalSavage = s;
        CPersonState::iBoatCapacity = k;
        CPersonState start, end;
        try {
            start.init(m, s, CPersonState::THIS_SHORE);
            end.init(m, s, CPersonState::OPPOSITE_SHORE);
        } catch(const CException &ex) {
            cerr << ex.code << ": " << ex.msg << "\n";
        } catch(...) {
            break;
        }
        CAstar game(start, end);
        game.run();
        if(game.bCanSolve == true) {
            cout << "your game can be solve:\n";
            cout << "the total steps is: " << game.iSteps << "\n";
            cout << "and the path is:\n";
            int len = game.vecSolve.size();
            cout << *((CPersonState*)(game.vecSolve[len - 1])) << "\n";
            for(int i = len - 2; i >= 0; --i) {
                cout << "\n    |\n";
                auto pstate = (CPersonState*)(game.vecSolve[i]);
                cout << "    |   (" << pstate->iMoveMissionary << ", " << pstate->iMoveSavage << ")\n";
                cout << "   \\|/\n\n";
                cout << *((CPersonState*)(game.vecSolve[i])) << "\n";
            }
        } else {
            cout << "sorry, your game can't be solve, please input another state.\n\n";
        }
        cout << "the total steps is: " << game.iSteps << "\n";
        cout << "and the max states is: " << game.iTotalStates << "\n";
        cout << "and the runtime is: " << game.lRunTime << "\n";
    }
    return 0;
}
main.cpp

  makefile 文件(和上面的相似,只是编译的目标项稍有不同):

CC = g++
COMOPT = -std=c++11 -g

INCLUDEDIR = -I./tools

LIBDIR = -L./tools
LIBS = -ltools
LINK = $(LIBDIR) $(LIBS)

# OBJS = $(patsubst %.cpp, %.o, $(wildcard *.cpp))
OBJS += person.o astar.o state.o

OUTPUT += across_river

.PHONY: all clean

all: $(OUTPUT)

# Sub-makes are invoked through $(MAKE) so that options such as -j/-n and the
# jobserver are forwarded. Recipe lines must start with a tab character.
across_river: $(OBJS) main.o
	$(MAKE) -C tools
	$(CC) -o $@ $^ $(LINK)

%.o: %.cpp
	$(CC) -o $@ -c $< $(COMOPT) $(INCLUDEDIR)

clean:
	$(MAKE) clean -C tools
	rm -f *.o
	rm -f $(OUTPUT)
makefile

  之后我用 Java 重写了一遍,除了面向对象的语法有些区别以外,其他部分几乎都一样:

首先是自定义异常类 MyException:

package tools;

/**
 * Custom unchecked exception: a thin wrapper around an error code and an
 * error message.
 */
public class MyException extends RuntimeException {

    private static final long serialVersionUID = 1L;

    public int code;        // error code
    public String msg;      // error message

    /**
     * @param code error code
     * @param msg  error message; also handed to RuntimeException so that
     *             getMessage() and stack traces show it instead of null
     */
    public MyException(int code, String msg) {
        super(msg);
        this.code = code;
        this.msg = msg;
    }

    @Override
    public String toString() {
        return "MyException [code=" + code + ", msg=" + msg + "]";
    }

}
MyException.java

计时类 TimeValue:

package tools;

/**
 * Simple stopwatch based on System.currentTimeMillis().
 */
public class TimeValue {

    private long milliSecond;   // start timestamp, in milliseconds

    /** Starts timing from the current system time (milliseconds). */
    public TimeValue() {
        this(System.currentTimeMillis());
    }

    /** Starts timing from the given timestamp, in milliseconds. */
    public TimeValue(long milliSecond) {
        super();
        this.milliSecond = milliSecond;
    }

    /** Returns the time elapsed since the start timestamp, in milliseconds. */
    public long costTime() {
        return System.currentTimeMillis() - milliSecond;
    }

    /** Restarts timing from the current system time. */
    public void reset() {
        milliSecond = System.currentTimeMillis();
    }

    @Override
    public String toString() {
        return "TimeValue [milliSecond=" + milliSecond + "]";
    }

}
TimeValue.java

抽象类 State:

package main;

import java.util.ArrayList;

import tools.MyException;

/**
 * Abstract node of the A* search tree.
 *
 * Subclasses define identity (hashCode/equals), successor generation and the
 * heuristic astar_g(); Astar drives the search only through this class.
 */
public abstract class State {

    public int steps;      // moves taken from the start state
    public State parent;   // predecessor on the search path, null for the root

    public State() {
        super();
        steps = 0;
        parent = null;
    }

    @Override
    public abstract int hashCode();

    @Override
    public abstract boolean equals(Object obj);

    /** Checks that this state and rhs are compatible; accepts everything by default. */
    public void checkSomeFields(State rhs) throws MyException {}

    /** Returns newly created successor states (steps/parent not yet filled in). */
    public abstract ArrayList<State> getNextState();

    /** Expands this state and stamps every successor with its depth and parent. */
    public ArrayList<State> __getNextState() {
        final ArrayList<State> children = getNextState();
        for (State child : children) {
            child.steps = steps + 1;
            child.parent = this;
        }
        return children;
    }

    /** Cost so far: the number of moves from the start state. */
    public long astar_f() {
        return steps;
    }

    /** Heuristic estimate of the remaining cost; smaller values are expanded first. */
    public abstract long astar_g();

    /** Priority key of the open list: cost so far plus the heuristic estimate. */
    public long astar_h() {
        final long costSoFar = astar_f();
        return costSoFar + astar_g();
    }
}

A*算法类 Astar:

package main;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import java.util.Comparator;
import java.util.HashMap;

import tools.MyException;
import tools.TimeValue;

/**
 * Main A* driver.
 *
 * After run(), vecSolve holds the solution path from the end state back to
 * the start state, and runTime holds the elapsed time in milliseconds.
 */
public class Astar {

    public State start;
    public State end;
    public boolean canSolve;
    public int steps;
    public ArrayList<State> vecSolve;   // solution path, end state first
    public long runTime;                // milliseconds spent in run()
    public int totalStates;             // distinct states known when run() finished

    public Astar(State start, State end) throws MyException {
        super();
        this.start = start;
        this.end = end;
        canSolve = false;
        steps = 0;
        vecSolve = new ArrayList<State>();
        runTime = 0;
        totalStates = 0;
        start.checkSomeFields(end);
    }

    static Set<State> getStateByStartAndSteps(State start, int steps) {
        // TODO: not ported from the C++ version yet; always returns an empty set.
        return new HashSet<>();
    }

    void run() {
        final TimeValue timer = new TimeValue();
        final Map<State, State> known = new HashMap<>();
        known.put(start, start);
        // Min-heap on astar_h().
        final Queue<State> open = new PriorityQueue<>(
                (a, b) -> Long.signum(a.astar_h() - b.astar_h()));
        open.add(start);

        while (!open.isEmpty()) {
            final State head = open.poll();
            if (head.equals(end)) {
                canSolve = true;
                steps = head.steps;
                for (State node = head; node != null; node = node.parent) {
                    vecSolve.add(node);
                }
                break;
            }
            for (State next : head.__getNextState()) {
                final State seen = known.get(next);
                if (seen == null) {
                    open.add(next);
                    known.put(next, next);
                } else if (seen.astar_f() > next.astar_f()) {
                    // Remove first so that the map keeps the new object as its key.
                    known.remove(next);
                    known.put(next, next);
                    open.add(next);
                }
            }
            if (known.size() > 3000 * 10000) {
                break;
            }
        }
        totalStates = known.size();
        runTime = timer.costTime();
    }
}

用于表示传教士过河状态的具体类 PersonState:

package main;

import java.util.ArrayList;

import tools.MyException;

public class PersonState extends State {
    
    // Problem-wide parameters shared by every state; assign them before calling init().
    // They start at -1 so init() can detect a missing assignment: an int field would
    // otherwise default to 0 and the "not initialized" checks in init() could never fire.
    static public int totalMissionary = -1;    // the total number of missionaries
    static public int totalSavage = -1;        // the total number of savages
    static public int boatCapacity = -1;       // the capacity of the boat
    
    private int missionary;        // the number of missionaries on the shore where the boat is anchored
    private int savage;            // the number of savages on the shore where the boat is anchored
    private int boatPosition;      // the position of the boat: THIS_SHORE or OPPOSITE_SHORE
    
    public static final int THIS_SHORE = 1;
    public static final int OPPOSITE_SHORE = 2;
    
    // The move (missionaries, savages carried by the boat) that produced this state from its
    // parent. Used only for display; deliberately not part of equals()/hashCode().
    public int moveMissionary;
    public int moveSavage;

    public PersonState() {
        super();
    }
    
    /**
     * Sets this state's fields and then validates them against the static problem parameters.
     * The fields are assigned before validation, so they stay set even if an exception is thrown.
     *
     * @param _m number of missionaries on the shore where the boat is anchored
     * @param _s number of savages on the shore where the boat is anchored
     * @param _b boat position, THIS_SHORE (1) or OPPOSITE_SHORE (2)
     * @throws MyException code 101-103 if a static parameter is still unset (-1),
     *         104-107 if the state itself is invalid
     */
    public void init(int _m, int _s, int _b) throws MyException {
        this.missionary = _m;
        this.savage = _s;
        this.boatPosition = _b;
        this.moveMissionary = moveSavage = 0;

        if(totalMissionary == -1) {
            throw new MyException(101, "the total number of missionaries has not been initialized.");
        }
        if(totalSavage == -1) {
            throw new MyException(102, "the total number of savages has not been initialized.");
        }
        if(boatCapacity == -1) {
            throw new MyException(103, "the capacity of the boat has not been initialized.");
        }
        if(missionary > totalMissionary) {
            throw new MyException(104, "the number of missionaries on this shore exceeded the total number.");
        }
        if(savage > totalSavage) {
            throw new MyException(105, "the number of savages on this shore exceeded the total number.");
        }
        // Missionaries on the boat's shore must not be outnumbered (unless there are none).
        if(missionary != 0 && missionary < savage) {
            throw new MyException(106, "the number of missionaries can\'t be less then the number of savages.");
        }
        if(boatPosition != THIS_SHORE && boatPosition != OPPOSITE_SHORE) {
            throw new MyException(107, "the value of iBoat is invalid, which must be CPersonState::THIS_SHORE or CPersonState::OPPOSITE_SHORE, you can use 1 or 2 certainly.");
        }
    }

    @Override
    public String toString() {
        return "PersonState [missionary=" + missionary + ", savage=" + savage + ", boatPosition=" + boatPosition
                + ", moveMissionary=" + moveMissionary + ", moveSavage=" + moveSavage + "]";
    }

    /**
     * Hashes exactly the fields compared by {@link #equals(Object)}. The move fields must stay
     * out: equal states reached by different moves need equal hash codes, or HashMap-based
     * de-duplication in the search treats them as distinct states.
     */
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + boatPosition;
        result = prime * result + missionary;
        result = prime * result + savage;
        return result;
    }

    /** Two states are equal when the boat position and the head counts on the boat's shore match. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (obj == null)
            return false;
        if (getClass() != obj.getClass())
            return false;
        PersonState other = (PersonState) obj;
        if (boatPosition != other.boatPosition)
            return false;
        if (missionary != other.missionary)
            return false;
        if (savage != other.savage)
            return false;
        return true;
    }

    /**
     * Enumerates every legal crossing of x missionaries and y savages (1 <= x + y <= boatCapacity)
     * from the boat's shore to the opposite shore. Each successor describes the arrival shore,
     * with the boat position flipped (3 - 1 = 2, 3 - 2 = 1) and the move recorded.
     */
    @Override
    public ArrayList<State> getNextState() {
        ArrayList<State> nextState = new ArrayList<>();
        // Head counts on the shore the boat will sail to.
        int oppo_m = totalMissionary - missionary;
        int oppo_s = totalSavage - savage;
        int mk = Math.min(missionary, boatCapacity);
        int sk = Math.min(savage, boatCapacity);
        for(int x = 0; x <= mk; ++x) {
            for(int y = 0; y <= sk; ++y) {
                if(x == 0 && y == 0)    continue;      // the boat cannot cross empty
                // Missionaries left behind would be outnumbered; a larger y may fix that.
                if(missionary - x != 0 && missionary - x < savage - y)    continue;
                // Boat overloaded, or savages outnumber missionaries on board: a larger y only gets worse.
                if(x + y > boatCapacity || (x != 0 && y > x) )    break;
                // Missionaries on the arrival shore would be outnumbered: a larger y only gets worse.
                if(oppo_m + x != 0 && oppo_m + x < oppo_s + y)    break;
                PersonState _next = new PersonState();
                _next.init(oppo_m + x, oppo_s + y, 3 - boatPosition);
                _next.moveMissionary = x;
                _next.moveSavage = y;
                nextState.add(_next);
            }
        }
        return nextState;
    }

    /**
     * Heuristic: a lower bound on the number of crossings still needed to move everybody off
     * the starting shore (THIS_SHORE).
     *
     * With r people remaining and capacity k, at least ceil(r / k) forward trips are needed, with
     * a return trip between consecutive forward trips. That gives 2 * ceil(r / k) - 1 crossings
     * if the boat is on the starting shore, and one extra return trip if it is not. The bound
     * never overestimates and changes by at most 1 per crossing, so A* still finds a shortest path.
     * The previous estimate (plain r) overestimates, e.g. r = 2 and k = 2 gives 2 instead of 1.
     */
    @Override
    public long astar_g() {
        int remain_num;
        if(boatPosition == THIS_SHORE) {
            remain_num = missionary + savage;
        } else {
            remain_num = (totalMissionary - missionary) + (totalSavage - savage);
        }
        // Goal reached, or a degenerate capacity (no successors are ever generated): keep the raw count.
        if(remain_num == 0 || boatCapacity <= 0) {
            return remain_num;
        }
        long forwardTrips = (remain_num + (long) boatCapacity - 1) / boatCapacity;
        return boatPosition == THIS_SHORE ? 2 * forwardTrips - 1 : 2 * forwardTrips;
    }

}
PersonState.java

最后是 main 函数,实现简单的交互:

package main;

import java.util.Scanner;

import tools.MyException;

public class Main {

    /**
     * Interactive loop. It reads "missionaries savages capacity" triples from stdin, solves each
     * puzzle with A*, and prints the path and statistics. It exits on the first input that
     * cannot be parsed as three ints.
     */
    public static void main(String[] args) {
        int m = 0, s = 0, k = 0;
        Scanner cin = new Scanner(System.in);
        while(true) {
            System.out.println("please input the number of missionaries, savages and the capacity of the boat, separate with a space:");
            try {
                m = cin.nextInt();
                s = cin.nextInt();
                k = cin.nextInt();
            } catch(Exception ex) {
                // Non-numeric input or end of input ends the session.
                System.out.println(ex.toString() + "\nbye~");
                break;
            }
            PersonState.totalMissionary = m;
            PersonState.totalSavage = s;
            PersonState.boatCapacity = k;
            PersonState start = new PersonState();
            PersonState end = new PersonState();
            try {
                start.init(m, s, PersonState.THIS_SHORE);
                end.init(m, s, PersonState.OPPOSITE_SHORE);
            } catch (MyException ex) {
                // The input describes an invalid puzzle: report it and ask again instead of
                // running the solver on a half-initialized start/end pair.
                System.out.println(ex.toString());
                continue;
            } catch(Exception ex) {
                break;
            }
            Astar game = new Astar(start, end);
            game.run();
            if(game.canSolve == true) {
                System.out.println("your game can be solve:");
                System.out.println("the total steps is: " + game.steps);
                System.out.println("and the path is:\n");
                // vecSolve holds the path goal-first, so print it back to front.
                int len = game.vecSolve.size();
                System.out.println(game.vecSolve.get(len - 1).toString());
                for(int i = len - 2; i >= 0; --i) {
                    System.out.println("\n    |");
                    PersonState personState = (PersonState)game.vecSolve.get(i);
                    System.out.println("    |   (" + personState.moveMissionary + ", " + personState.moveSavage + ")");
                    System.out.println("   \\|/\n");
                    System.out.println(personState.toString());
                }
                System.out.println();
            } else {
                System.out.println("sorry, your game can't be solve, please input another state.\n");
            }
            System.out.println("the total steps is: " + game.steps);
            System.out.println("and the max states is: " + game.totalStates);
            System.out.println("and the runtime is: " + game.runTime + "\n");
        }
        cin.close();
    }

}
Main.java

  完整的传教士过河代码:missionary_across_river

posted @ 2016-12-13 16:10  Newdawn_ALM  阅读(1188)  评论(0编辑  收藏  举报