C++实现Prim算法

闲来无聊,前两天看到一篇关于算法实现的文章。里面又关于图的各种算法介绍,正好上学期还学过图论,现在还记得一点点,先来实现个prim算法:

表示图的文件的内容大体上是这样的:

 1 2.0    1.0    1.0
 2 3.0    1.0    1.0
 3 4.0    1.0    1.0
 4 5.0    1.0    1.0
 5 6.0    1.0    1.0
 6 7.0    1.0    1.0
 7 8.0    1.0    1.0
 8 9.0    1.0    1.0
 9 11.0 1.0 1.0
10 12.0 1.0 1.0
11 13.0 1.0 1.0
12 14.0 1.0 1.0
13 18.0 1.0 1.0
14 20.0 1.0 1.0
15 22.0 1.0 1.0
16 32.0 1.0 1.0
17 34.0 10.0 1.0
18 34.014.0 1.0
19 33.0 15.0 1.0
20 34.0 15.0 1.0
21 33.0 16.0 1.0
22 34.0 16.0 1.0
23 33.0 19.0 1.0
24 34.0 19.0 1.0
25 3.0    2.0    1.0
26 4.0    2.0    1.0
27 8.0    2.0    1.0
28 14.0 2.0 1.0
29 18.0 2.0 1.0
30 20.0 2.0 1.0
31 22.0 2.0 1.0
32 31.0 2.0 1.0
33 34.0 20.0 1.0
34 33.0 21.0 1.0
35 34.0 21.0 1.0
36 33.0 23.0 1.0
37 34.0 23.0 1.0
38 26.0 24.0 1.0
39 28.0 24.0 1.0
40 30.0 24.0 1.0
41 33.0 24.0 1.0
42 34.0 24.0 1.0
43 26.0 25.0 1.0
44 28.0 25.0 1.0
45 32.0 25.0 1.0
46 32.0 26.0 1.0
47 30.0 27.0 1.0
48 34.0 27.0 1.0
49 34.0 28.0 1.0
50 32.0 29.0 1.0
51 34.0 29.0 1.0
52 4.0    3.0    1.0
53 8.0    3.0    1.0
54 9.0    3.0    1.0
55 10.0 3.0 1.0
56 14.0 3.0 1.0
57 28.0 3.0 1.0
58 29.0 3.0 1.0
59 33.0 3.0 1.0
60 33.0 30.0 1.0
61 34.0 30.0 1.0
62 33.0 31.0 1.0
63 34.0 31.0 1.0
64 33.0 32.0 1.0
65 34.0 32.0 1.0
66 34.0 33.0 1.0
67 8.0    4.0    1.0
68 13.0 4.0 1.0
69 14.0 4.0 1.0
70 7.0    5.0    1.0
71 11.0 5.0 1.0
72 7.0    6.0    1.0
73 11.0 6.0 1.0
74 17.0 6.0 1.0
75 17.0 7.0 1.0
76 31.0 9.0 1.0
77 33.0 9.0 1.0
78 34.0 9.0 1.0
View Code

注意,从左到右分别是当前节点,连接的节点,边的权重,下面首先就是设计数据结构了:

1 class Pair {        //pair代表了与某个点相连的一条边的权重
2 private:            //,以及和这条变相连的另一个顶点是哪个
3     double edge_weight;
4     int adacent_vertex;
5 public:
6     Pair(int, double);
7     double weight() const;
8     int vertex() const;
9 };

上面的pair代表一个点相邻的边的权重以及这条边与哪一个顶点是相连的。

1 class Node { //代表了一个节点,其包含的信息有与其相连的
2 private:     //某一条边的权重以及和这条边相连的另一个顶点。 
3     Pair element;
4     Node *next_node;
5 public:
6     Node(Pair e, Node * = NULL);
7     Pair retrieve() const;
8     Node *next() const;
9 };

代表一个节点,注意这个节点的next_node的值与相邻节点没有任何关系,只是代表链表的下一个节点,下面介绍的是链表:

 1 class List { //List中存放的是每个具体的节点,
 2 private:     //每个节点后面的链表代表的是与其相邻接的节点
 3     Node *list_head;
 4 public:
 5     List();
 6     // Accessors
 7     bool empty() const;
 8     Pair front() const;
 9     Node *head() const;
10     void push_front(Pair);
11     Pair pop_front();
12     void print();
13 };

下面的cell实际上代表的就是一颗生成树了:

 1 class Cell { //cell代表的就是一个具体的生成树了
 2 private:
 3     bool visited;
 4     double distance;
 5     int parent;
 6 public:
 7     Cell(bool = false, double = INFTY, int = 0);
 8     bool isvisited() const;
 9     double get_distance() const;
10     int get_parent() const;
11 };

-----------------------------------------------------------------------------------------------------------

下面是数据结构的具体定义了:

  1 #include "structure.h"
  2 
  3 Pair::Pair(int e, double m) :
  4 edge_weight(m),
  5 adacent_vertex(e) {
  6     // empty constructor
  7 }
  8 
  9 double Pair::weight()const {
 10     return edge_weight;
 11 }
 12 int Pair::vertex()const {
 13     return adacent_vertex;
 14 }
 15 
 16 
 17 Node::Node(Pair e, Node *n) :
 18 element(e),
 19 next_node(n) {
 20     // empty constructor
 21 }
 22 
 23 Pair Node::retrieve() const{
 24     return element;
 25 }
 26 Node *Node::next() const {
 27     return next_node;
 28 }
 29 
 30 
 31 List::List() :list_head(NULL) {
 32     // empty constructor
 33 }
 34 
 35 bool List::empty() const {
 36     if (list_head == NULL) {
 37         return true;
 38     }
 39     else {
 40         return false;
 41     }
 42 }
 43 
 44 Node *List::head() const {
 45     return list_head;
 46 
 47 }
 48 
 49 Pair List::front() const {
 50 
 51     if (empty()) {
 52         cout << "error! the list is empty";
 53     }
 54     return head()->retrieve();
 55 }
 56 
 57 
 58 void List::push_front(Pair e) {
 59     if (empty()) {
 60         list_head = new Node(e, NULL);
 61     }
 62     else {
 63         list_head = new Node(e, head());
 64     }
 65 
 66 }
 67 
 68 Pair List::pop_front() {
 69     if (empty()) {
 70         cout << "error! the list is empty";
 71     }
 72     Pair e = front();
 73     Node *ptr = list_head;
 74     list_head = list_head->next();
 75     delete ptr;
 76     return e;
 77 
 78 }
 79 
 80 
 81 void List::print() {
 82     if (empty()) {
 83         cout << "empty" << endl;
 84     }
 85     else {
 86         for (Node *ptr = head(); ptr != NULL; ptr = ptr->next())
 87         {
 88             cout << "<" << ptr->retrieve().vertex() << " " << ptr->retrieve().weight() << "> ";
 89         }
 90         cout << endl;
 91     }
 92 }
 93 
 94 
 95 Cell::Cell(bool v, double d, int p) :
 96 visited(v),
 97 distance(d),
 98 parent(p) {
 99     // empty constructor
100 }
101 
102 bool Cell::isvisited() const {
103     return visited;
104 }
105 
106 double Cell::get_distance()const {
107     return distance;
108 }
109 int Cell::get_parent()const {
110     return parent;
111 }

 

好了有了上面的数据结构,实现Prim算法就比较简单了:

 1 Cell* Prim(List * graph, int n, int start)  //传入一个邻接数组,以及数组的大小,
 2 {                                            //以及邻接矩阵的起始点,求一个最小生成树
 3     Cell* Table = new Cell[n + 1]; //n+1的目的是节点是从1开始数的,所以要多添加一个
 4     //Table[start]=Cell(false,0,0);//这里的false是否换成True会好一点?
 5     Table[start] = Cell(true, 0, 0);
 6     /* 实现prim算法*/
 7     int currMinNode, currParent = 0;
 8     double currMin;
 9     for (;;){
10         currMin = INFTY; //注意这个的类型是double类型
11         currMinNode = 0;
12         currParent = 0;
13         for (int i = 1; i <= n; ++i){
14             if (Table[i].isvisited()){//已经被访问过了
15                 Node * currNode = graph[i].head();
16                 while (currNode != NULL){ //从该节点开始,然后访问其所有的邻接的节点
17                     int tmpNode = currNode->retrieve().vertex();
18                     double tmpWeight = currNode->retrieve().weight();
19                     if (!Table[tmpNode].isvisited() && tmpWeight < currMin){
20                         currMin = tmpWeight;
21                         currMinNode = tmpNode;
22                         currParent = i;
23                     }
24                     currNode = currNode->next(); //取下一个邻接的节点
25                 }
26             }
27             else
28                 continue;
29 
30         }
31         Table[currMinNode] = Cell(true, currMin, currParent);//找到下一个节点,将其置为True
32         if (currMinNode == 0) //如果所有的节点都已经遍历完毕的话,就停止下一次的寻找
33             break;
34 
35     }
36     return Table;
37 }

 

顺手写个打印生成树的函数:

1 void PrintTable(Cell* Table, int n)
2 {
3     for (int i = 1; i <= n; i++)
4         cout << Table[i].isvisited() << " " <<
5         Table[i].get_distance() << " " <<
6         Table[i].get_parent() << endl;
7 }

主函数如下所示:

 1 #include "structure.h"
 2 #include "Prim.h"
 3 int main()
 4 {
 5     List * graph = new List[N];
 6     char *inputfile = "primTest.txt";
 7     ifstream fin; //输入文件 .join后结果
 8     int n = 0;
 9     double x, y, w;
10     fin.open(inputfile);
11     while (!fin.eof())
12     {
13         fin >> x >> y >> w;
14         Pair a(y, w);
15         Pair b(x, w);
16         graph[int(x)].push_front(a);
17         graph[int(y)].push_front(b);
18         if (n <= x)
19             n = x;
20         if (n <= y)
21             n = y;
22     }
23     fin.close();
24     cout << "The total Node number is "
25         << n << endl;
26     for (int i = 1; i <= n; i++)
27         graph[i].print();
28 
29     Cell* Table = Prim(graph, n, 2);
30     cout << "-----------------------\n";
31     PrintTable(Table, n);33     return 0;
34 }

 最后的结果如下所示:

写的有点乱,见谅见谅

其实这个也可以实现Dijkstra算法,那个好像没学过,看看以后有时间再来写。

posted @ 2015-12-27 22:24  eversliver  阅读(5086)  评论(0编辑  收藏  举报