基于朴素贝叶斯的扑克牌花色识别

本程序只对扑克牌的花色进行训练和识别,对扑克牌上的数字的识别在以后的学习中再进行完善。

本次只是简单地提取了扑克牌的 RGB 均值(3 个)、HSV 均值(3 个)、7 个 Hu 不变矩以及长宽比(1 个),共 14 个简单的特征。其中,为了避免目标在图像中的位置、旋转角度等因素对长宽比的影响,长宽比是根据目标区域的最小外接矩形(而不是与坐标轴对齐的外接矩形)计算的。

部分图像如下图所示:

 

特征提取的部分代码如下所示:

[cpp] view plain copy 
1.void CPokeAlgorithmDlg::CollectCharacter(IplImage* img, CvMat* mat, int rows)  
2.{  
3.    if (img != nullptr)  
4.    {  
5.        showImage(img, IDC_PIC1);                   //显示图像  
6.  
7.        IplImage* bitImage = nullptr, *grayImage = nullptr, *hsvImage = nullptr;  
8.  
9.        bitImage = cvCreateImage(cvGetSize(img), IPL_DEPTH_8U, 1);  
10.        grayImage = cvCreateImage(cvGetSize(img), IPL_DEPTH_8U, 1);  
11.        hsvImage = cvCreateImage(cvGetSize(img), IPL_DEPTH_8U, 3);  
12.  
13.  
14.        cvCvtColor(img, hsvImage, CV_RGB2HSV);  
15.        cvCvtColor(img, grayImage, CV_RGB2GRAY);  
16.  
17.        cvSmooth(grayImage, grayImage, CV_MEDIAN);  
18.        cvThreshold(grayImage, bitImage, 128, 255.0, CV_THRESH_BINARY);  
19.  
20.        cvNot(bitImage, bitImage);  
21.  
22.        IplConvKernel* element = cvCreateStructuringElementEx(5, 5, 2, 2, CV_SHAPE_ELLIPSE);  
23.        cvSmooth(bitImage, bitImage, CV_MEDIAN);  
24.        cvErode(bitImage, bitImage, element, 1);  
25.        cvDilate(bitImage, bitImage, element, 1);  
26.        cvReleaseStructuringElement(&element);  
27.        element = NULL;  
28.  
29.  
30.        CvMemStorage* storage = cvCreateMemStorage(0);  
31.        CvSeq* contour = 0;  
32.        cvFindContours(bitImage, storage, &contour, sizeof(CvContour), CV_RETR_EXTERNAL, CV_CHAIN_APPROX_NONE);     //轮廓检索  
33.  
34.        for (; contour != 0; contour = contour->h_next)  
35.        {  
36.            double area = fabs(cvContourArea(contour, CV_WHOLE_SEQ));  
37.  
38.            if (area > 2000)   //此处阈值需重新调节  
39.            {  
40.                cvDrawContours(bitImage, contour, cvScalarAll(255), cvScalarAll(255), -1, CV_FILLED, 8);  
41.                CvRect rect = cvBoundingRect(contour, 0);  
42.  
43.                CvBox2D minRect = cvMinAreaRect2(contour, storage);  
44.  
45.                CvPoint2D32f rectPts[4] = { 0 };  
46.                cvBoxPoints(minRect, rectPts);  
47.                int nPts = 4;   // 4 个顶点  
48.  
49.                CvPoint minRectPts[4] = { 0 };  
50.                for (int i = 0; i < 4; ++i)  
51.                {  
52.                    minRectPts[i] = cvPointFrom32f(rectPts[i]);    //将 cvPoint2D32f 转化为 CvPoint  
53.                }  
54.                CvPoint *pt = minRectPts;  
55.  
56.                //在图像中绘制矩形框  
57.                cvPolyLine(bitImage, &pt, &nPts, 1, 1, cvScalarAll(255), 1);  
58.  
59.                int l1 = sqrtf((pt[0].x - pt[1].x)*(pt[0].x - pt[1].x) + (pt[0].y - pt[1].y)*(pt[0].y - pt[1].y));  
60.                int l2 = sqrtf((pt[2].x - pt[1].x)*(pt[2].x - pt[1].x) + (pt[2].y - pt[1].y)*(pt[2].y - pt[1].y));  
61.  
62.                int length = l1 > l2 ? l1 : l2;   //取较长边为图形的长  
63.                int width = l1 > l2 ? l2 : l1;     //取较短边为图形的宽  
64.  
65.                double r = (width * 1.0) / length;   //长宽比  
66.  
67.                cvSetReal2D(mat, rows, 0, r);  
68.              
69.                double RMean = 0, GMean = 0, BMean = 0;  
70.                double HMean = 0, SMean = 0, VMean = 0;  
71.                int nCount = 0;  
72.  
73.                for (int imgRow = rect.y; imgRow < rect.y + rect.height; ++imgRow)  
74.                {  
75.                    for (int imgCol = rect.x; imgCol < rect.x + rect.width; ++imgCol)  
76.                    {  
77.                        CvScalar s = cvGet2D(bitImage, imgRow, imgCol);  
78.  
79.                        if (s.val[0] == 255)  
80.                        {  
81.                            s = cvGet2D(img, imgRow, imgCol);  
82.                            RMean += s.val[2];  
83.                            GMean += s.val[1];  
84.                            BMean += s.val[0];  
85.  
86.                            s = cvGet2D(hsvImage, imgRow, imgCol);  
87.                            HMean += s.val[0];  
88.                            SMean += s.val[1];  
89.                            VMean += s.val[2];  
90.  
91.                            ++nCount;  
92.                        }  
93.                    }  
94.                }// end RGB,HSV for  
95.  
96.                RMean /= nCount;  
97.                GMean /= nCount;  
98.                BMean /= nCount;  
99.  
100.                HMean /= nCount;  
101.                SMean /= nCount;  
102.                VMean /= nCount;  
103.  
104.                  
105.                cvSetReal2D(mat, rows, 1, RMean);  
106.                cvSetReal2D(mat, rows, 2, GMean);  
107.                cvSetReal2D(mat, rows, 3, BMean);  
108.                cvSetReal2D(mat, rows, 4, HMean);  
109.                cvSetReal2D(mat, rows, 5, SMean);  
110.                cvSetReal2D(mat, rows, 6, VMean);  
111.  
112.                //7个不变矩  
113.  
114.                CvMoments moments;  
115.                cvMoments(contour, &moments, 1);  
116.                CvHuMoments  huMoments;  
117.                cvGetHuMoments(&moments, &huMoments);  
118.  
119.                double hu1 = huMoments.hu1;  
120.                double hu2 = huMoments.hu2;  
121.                double hu3 = huMoments.hu3;  
122.                double hu4 = huMoments.hu4;  
123.                double hu5 = huMoments.hu5;  
124.                double hu6 = huMoments.hu6;  
125.                double hu7 = huMoments.hu7;  
126.  
127.                cvSetReal2D(mat, rows, 7, hu1);  
128.                cvSetReal2D(mat, rows, 8, hu2);  
129.                cvSetReal2D(mat, rows, 9, hu3);  
130.                cvSetReal2D(mat, rows, 10, hu4);  
131.                cvSetReal2D(mat, rows, 11, hu5);  
132.                cvSetReal2D(mat, rows, 12, hu6);  
133.                cvSetReal2D(mat, rows, 13, hu7);  
134.            }// end if  
135.        }  
136.  
137.        showImage(hsvImage, IDC_PIC3);  
138.        showImage(bitImage, IDC_PIC2);  
139.  
140.  
141.        //释放内存  
142.        cvReleaseMemStorage(&storage);  
143.        storage = nullptr;  
144.        cvReleaseImage(&bitImage);  
145.        bitImage = nullptr;  
146.        cvReleaseImage(&grayImage);  
147.        grayImage = nullptr;  
148.        cvReleaseImage(&hsvImage);  
149.        hsvImage = nullptr;  
150.    }  
151.  
152.    //释放内存  
153.    cvReleaseImage(&img);  
154.    img = nullptr;  
155.}  

朴素贝叶斯(Bayes)分类器的训练代码如下所示:

[cpp] view plain copy 
1.Book* book = xlCreateXMLBookW();  
2.  
3.    CvMat* dataMat = NULL;  
4.  
5.    if (book->load(L"Data.xlsx"))  
6.    {  
7.        Sheet *sheet = book->getSheet(0);  
8.  
9.        int myrow = sheet->lastRow();  
10.        int mycol = sheet->lastCol();  
11.  
12.        if (sheet)  
13.        {  
14.            CvMat* importMat = cvCreateMat(myrow, mycol, CV_32FC1);  //存储导入数据  
15.  
16.            for (auto i = 0; i < myrow; ++i)  
17.            {  
18.                for (auto j = 0; j < mycol; j++)  
19.                {  
20.                    double temp = sheet->readNum(i, j);  
21.                    cvSetReal2D(importMat, i, j, temp);  
22.                }  
23.            }// end for  
24.  
25.            dataMat = cvCloneMat(importMat);  
26.        }// end if  
27.    }  
28.  
29.    book->release();  
30.  
31.    MessageBox(L"数据导入完成");  
32.  
33.    CvMat* lableMat = cvCreateMat(dataMat->rows, 1, CV_32FC1);       //构建样本的分类标签  
34.    cvZero(lableMat);  
35.  
36.    for (int i = 0; i < 4; ++i)          //共分了 20 个不同的种类  
37.    {  
38.        for (int j = 0; j < 10; ++j)     //每个品种共50个籽粒  
39.        {  
40.            cvSetReal2D(lableMat, i * 10 + j, 0, i + 1);  
41.        }  
42.    }  
43.  
44.    CvNormalBayesClassifier nbc;  
45.    nbc.train(dataMat, lableMat);  
46.    nbc.save("bayes.txt");  
47.  
48.    MessageBox(L"数据训练完成");  
49.  
50.    CvMat* nbcResult = cvCreateMat(dataMat->rows, 1, CV_32FC1);  
51.    CvMat* nbcRow = NULL;  
52.  
53.    for (int i = 0; i < dataMat->rows; ++i)  
54.    {  
55.        nbcRow = cvCreateMat(1, dataMat->cols, CV_32FC1);  
56.  
57.        for (int j = 0; j < dataMat->cols; ++j)  
58.        {  
59.            float temp = cvGetReal2D(dataMat, i, j);  
60.            cvSetReal2D(nbcRow, 0, j, temp);  
61.        }  
62.  
63.        unsigned int ret = 0;  
64.        ret = nbc.predict(nbcRow);  
65.        cvSetReal2D(nbcResult, i, 0, ret);  
66.        cvReleaseMat(&nbcRow);  
67.        nbcRow = NULL;  
68.    }  
69.  
70.    int nCount = 0;  
71.  
72.    for (int i = 0; i < 4; ++i)  
73.    {  
74.        for (int j = 0; j < 10; ++j)  
75.        {  
76.            int ret = cvGetReal2D(nbcResult, i * 10 + j, 0);  
77.            if (ret == (i + 1))  
78.            {  
79.                ++nCount;  
80.            }  
81.        }  
82.    }  
83.  
84.    float recognize = 100 * nCount / 10 / 4;  
85.  
86.    CString str;  
87.    str.Format(L"朴素贝叶斯 识别率为: %f", recognize);  
88.    str = str + L"%";  
89.    MessageBox(str);  
90.      

 识别代码如下所示:

[cpp] view plain copy
1.CvNormalBayesClassifier nbc;  
2.    nbc.load("bayes.txt");  
3.  
4.    CFileDialog dlg(TRUE, NULL, NULL, 0, L"图片文件(*.jpg)|*.jpg||");  
5.    if (dlg.DoModal() == IDOK)  
6.    {  
7.        USES_CONVERSION;  
8.        const char* loadPath = W2A(dlg.GetPathName());  
9.        IplImage* testImage = cvLoadImage(loadPath);  
10.  
11.        CvMat* mat = cvCreateMat(1, 14, CV_32FC1);  
12.        CollectCharacter(testImage, mat, 0);  
13.  
14.        int  ret = nbc.predict(mat);  
15.        CString str;  
16.        switch (ret)  
17.        {  
18.        case 1:  
19.            str = L"黑桃";  
20.            break;  
21.        case 2:  
22.            str = "红桃";  
23.            break;  
24.        case 3:  
25.            str = "梅花";  
26.            break;  
27.        case 4:  
28.            str = "方块";  
29.            break;  
30.        }  
31.        AfxMessageBox(str);  
32.        cvReleaseMat(&mat);  
33.        mat = NULL;  
34.    }//end if     

 

posted @ 2017-11-01 14:35  2206  阅读(1390)  评论(0编辑  收藏  举报