Playing-Card Suit Recognition Based on Naive Bayes
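This program trains and recognizes only the suit of a playing card; recognizing the number printed on the card is left for future work.
This version extracts 14 simple features from the target region: the mean R, G and B values, the mean H, S and V values, the 7 Hu invariant moments, and the aspect ratio. To keep the aspect ratio from being affected by the card's position (and similar factors such as rotation) in the image, it is computed from the minimum-area bounding rectangle of the target region rather than from the upright bounding box.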
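The aspect-ratio feature and the 14-column layout can be sketched without OpenCV. In this sketch, `Pt`, `aspectRatio` and the `FeatureIndex` enum are hypothetical names introduced for illustration. The corners are assumed to be listed in order around the rectangle, as `cvBoxPoints` returns them, so that corners 0-1 and 1-2 share two adjacent sides. The column order mirrors the `cvSetReal2D` calls in the feature-extraction code.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>

// Column layout of one 14-feature row (order of the cvSetReal2D calls).
enum FeatureIndex {
    kAspectRatio = 0,
    kRMean, kGMean, kBMean,                    // columns 1-3
    kHMean, kSMean, kVMean,                    // columns 4-6
    kHu1, kHu2, kHu3, kHu4, kHu5, kHu6, kHu7,  // columns 7-13
    kNumFeatures                               // 14
};
static_assert(kNumFeatures == 14, "1 aspect ratio + 3 RGB + 3 HSV + 7 Hu");

// Hypothetical stand-in for CvPoint2D32f so the sketch has no OpenCV dependency.
struct Pt { double x, y; };

// Aspect ratio (short side / long side) of a rotated rectangle, computed from
// its four corners listed in order around the rectangle. Side lengths do not
// change when the rectangle is moved or rotated, so neither does the ratio.
double aspectRatio(const std::array<Pt, 4>& c) {
    const double l1 = std::hypot(c[0].x - c[1].x, c[0].y - c[1].y);  // side c0-c1
    const double l2 = std::hypot(c[2].x - c[1].x, c[2].y - c[1].y);  // side c1-c2
    const double longSide = std::max(l1, l2);
    assert(longSide > 0.0 && "degenerate rectangle");
    return std::min(l1, l2) / longSide;
}
```

A 4 x 2 rectangle gives 0.5 whether it is upright or rotated by any angle. An upright bounding box would instead grow as the card tilts.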
Some of the sample images are shown below:
Part of the feature-extraction code is shown below:
```cpp
// Extracts the 14 features of one card image into row `rows` of `mat`:
// column 0 = aspect ratio, 1-3 = R/G/B means, 4-6 = H/S/V means, 7-13 = Hu moments.
// Note: this function releases `img` before returning.
void CPokeAlgorithmDlg::CollectCharacter(IplImage* img, CvMat* mat, int rows)
{
    if (img != nullptr)
    {
        showImage(img, IDC_PIC1); // display the input image

        IplImage* bitImage  = cvCreateImage(cvGetSize(img), IPL_DEPTH_8U, 1);
        IplImage* grayImage = cvCreateImage(cvGetSize(img), IPL_DEPTH_8U, 1);
        IplImage* hsvImage  = cvCreateImage(cvGetSize(img), IPL_DEPTH_8U, 3);

        // cvLoadImage returns pixels in BGR order, so use the BGR conversion codes
        cvCvtColor(img, hsvImage, CV_BGR2HSV);
        cvCvtColor(img, grayImage, CV_BGR2GRAY);

        // Median filter into a separate buffer (older OpenCV documentation states that
        // median smoothing cannot run in place), then binarize and invert so the suit is white
        cvSmooth(grayImage, bitImage, CV_MEDIAN);
        cvThreshold(bitImage, bitImage, 128, 255.0, CV_THRESH_BINARY);
        cvNot(bitImage, bitImage);

        // Second median filter, then a morphological opening (erode + dilate) to remove noise
        IplConvKernel* element = cvCreateStructuringElementEx(5, 5, 2, 2, CV_SHAPE_ELLIPSE);
        cvSmooth(bitImage, grayImage, CV_MEDIAN);
        cvErode(grayImage, bitImage, element, 1);
        cvDilate(bitImage, bitImage, element, 1);
        cvReleaseStructuringElement(&element);

        CvMemStorage* storage = cvCreateMemStorage(0);
        CvSeq* contour = nullptr;
        // Retrieve outer contours only (cvFindContours modifies bitImage)
        cvFindContours(bitImage, storage, &contour, sizeof(CvContour), CV_RETR_EXTERNAL, CV_CHAIN_APPROX_NONE);

        for (; contour != nullptr; contour = contour->h_next)
        {
            double area = fabs(cvContourArea(contour, CV_WHOLE_SEQ));

            if (area > 2000) // this threshold needs to be retuned for other image sizes
            {
                // Fill the contour with 255 so it can be used as a mask below
                cvDrawContours(bitImage, contour, cvScalarAll(255), cvScalarAll(255), -1, CV_FILLED, 8);
                CvRect rect = cvBoundingRect(contour, 0);

                // Minimum-area (rotated) bounding rectangle and its 4 corners
                CvBox2D minRect = cvMinAreaRect2(contour, storage);
                CvPoint2D32f rectPts[4];
                cvBoxPoints(minRect, rectPts);

                // Lengths of two adjacent sides, in floating point
                double dx1 = rectPts[0].x - rectPts[1].x, dy1 = rectPts[0].y - rectPts[1].y;
                double dx2 = rectPts[2].x - rectPts[1].x, dy2 = rectPts[2].y - rectPts[1].y;
                double l1 = sqrt(dx1 * dx1 + dy1 * dy1);
                double l2 = sqrt(dx2 * dx2 + dy2 * dy2);

                double length = l1 > l2 ? l1 : l2; // longer side
                double width  = l1 > l2 ? l2 : l1; // shorter side
                double r = length > 0 ? width / length : 0.0; // aspect ratio (short / long)
                cvSetReal2D(mat, rows, 0, r);

                // Mean RGB and HSV over the pixels inside the filled contour
                double RMean = 0, GMean = 0, BMean = 0;
                double HMean = 0, SMean = 0, VMean = 0;
                int nCount = 0;

                for (int imgRow = rect.y; imgRow < rect.y + rect.height; ++imgRow)
                {
                    for (int imgCol = rect.x; imgCol < rect.x + rect.width; ++imgCol)
                    {
                        if (cvGet2D(bitImage, imgRow, imgCol).val[0] == 255)
                        {
                            CvScalar s = cvGet2D(img, imgRow, imgCol); // BGR order
                            BMean += s.val[0];
                            GMean += s.val[1];
                            RMean += s.val[2];

                            s = cvGet2D(hsvImage, imgRow, imgCol);
                            HMean += s.val[0];
                            SMean += s.val[1];
                            VMean += s.val[2];

                            ++nCount;
                        }
                    }
                }

                if (nCount > 0) // avoid division by zero on an empty mask
                {
                    RMean /= nCount; GMean /= nCount; BMean /= nCount;
                    HMean /= nCount; SMean /= nCount; VMean /= nCount;
                }

                cvSetReal2D(mat, rows, 1, RMean);
                cvSetReal2D(mat, rows, 2, GMean);
                cvSetReal2D(mat, rows, 3, BMean);
                cvSetReal2D(mat, rows, 4, HMean);
                cvSetReal2D(mat, rows, 5, SMean);
                cvSetReal2D(mat, rows, 6, VMean);

                // The 7 Hu invariant moments of the contour
                CvMoments moments;
                cvMoments(contour, &moments, 1);
                CvHuMoments hu;
                cvGetHuMoments(&moments, &hu);

                cvSetReal2D(mat, rows, 7, hu.hu1);
                cvSetReal2D(mat, rows, 8, hu.hu2);
                cvSetReal2D(mat, rows, 9, hu.hu3);
                cvSetReal2D(mat, rows, 10, hu.hu4);
                cvSetReal2D(mat, rows, 11, hu.hu5);
                cvSetReal2D(mat, rows, 12, hu.hu6);
                cvSetReal2D(mat, rows, 13, hu.hu7);

                // Draw the rotated rectangle only after the features are computed,
                // so its 255-valued outline does not leak into the color mask
                CvPoint minRectPts[4];
                for (int i = 0; i < 4; ++i)
                    minRectPts[i] = cvPointFrom32f(rectPts[i]); // CvPoint2D32f -> CvPoint
                CvPoint* pt = minRectPts;
                int nPts = 4; // 4 vertices
                cvPolyLine(bitImage, &pt, &nPts, 1, 1, cvScalarAll(255), 1);
            } // end if
        }

        showImage(hsvImage, IDC_PIC3);
        showImage(bitImage, IDC_PIC2);

        // Release memory
        cvReleaseMemStorage(&storage);
        cvReleaseImage(&bitImage);
        cvReleaseImage(&grayImage);
        cvReleaseImage(&hsvImage);
    }

    // Release the input image; the caller must not use it afterwards
    cvReleaseImage(&img);
}
```
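The seven values returned by `cvGetHuMoments` are built from normalized central moments. The following plain-C++ sketch computes them from a binary mask, using a hypothetical helper `huMoments` that is not an OpenCV function. It sums over pixels, whereas `cvMoments` called on a contour integrates over the polygon outline, so its numbers will differ slightly from OpenCV's. It does show why these features do not depend on where the suit symbol sits: measuring from the centroid removes translation, dividing by a power of the area m00 removes scale, and Hu's combinations remove rotation.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <vector>

// Hu's seven invariant moments of a binary mask.
// mask is row-major: mask[y * width + x] != 0 marks a foreground pixel.
std::array<double, 7> huMoments(const std::vector<unsigned char>& mask, int width, int height) {
    assert(static_cast<int>(mask.size()) == width * height);

    // Raw moments m00, m10, m01 give the area and the centroid.
    double m00 = 0.0, m10 = 0.0, m01 = 0.0;
    for (int y = 0; y < height; ++y)
        for (int x = 0; x < width; ++x)
            if (mask[y * width + x]) { m00 += 1.0; m10 += x; m01 += y; }
    assert(m00 > 0.0 && "empty mask");
    const double xc = m10 / m00, yc = m01 / m00;

    // Central moments mu[p][q] with 2 <= p + q <= 3, measured from the centroid.
    double mu[4][4] = {};
    for (int y = 0; y < height; ++y)
        for (int x = 0; x < width; ++x) {
            if (!mask[y * width + x]) continue;
            const double dx = x - xc, dy = y - yc;
            for (int p = 0; p <= 3; ++p)
                for (int q = 0; p + q <= 3; ++q)
                    if (p + q >= 2) mu[p][q] += std::pow(dx, p) * std::pow(dy, q);
        }

    // Normalized central moments: eta_pq = mu_pq / m00^(1 + (p + q) / 2).
    auto eta = [&](int p, int q) { return mu[p][q] / std::pow(m00, 1.0 + (p + q) / 2.0); };
    const double n20 = eta(2, 0), n02 = eta(0, 2), n11 = eta(1, 1);
    const double n30 = eta(3, 0), n03 = eta(0, 3), n21 = eta(2, 1), n12 = eta(1, 2);

    // Hu's rotation-invariant combinations.
    const double a = n30 + n12, b = n21 + n03;
    std::array<double, 7> h;
    h[0] = n20 + n02;
    h[1] = (n20 - n02) * (n20 - n02) + 4.0 * n11 * n11;
    h[2] = (n30 - 3.0 * n12) * (n30 - 3.0 * n12) + (3.0 * n21 - n03) * (3.0 * n21 - n03);
    h[3] = a * a + b * b;
    h[4] = (n30 - 3.0 * n12) * a * (a * a - 3.0 * b * b) + (3.0 * n21 - n03) * b * (3.0 * a * a - b * b);
    h[5] = (n20 - n02) * (a * a - b * b) + 4.0 * n11 * a * b;
    h[6] = (3.0 * n21 - n03) * a * (a * a - 3.0 * b * b) - (n30 - 3.0 * n12) * b * (3.0 * a * a - b * b);
    return h;
}
```

For a filled N x N square, the first moment works out to (N^2 - 1) / (6 N^2), which is 0.165 for N = 10, wherever the square is placed.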
Bayes training code:
```cpp
// Import the feature matrix from Data.xlsx with LibXL: one row per sample, 14 columns
Book* book = xlCreateXMLBookW();
CvMat* dataMat = nullptr;

if (book != nullptr && book->load(L"Data.xlsx"))
{
    Sheet* sheet = book->getSheet(0);
    if (sheet != nullptr) // check the sheet before calling lastRow()/lastCol()
    {
        int myrow = sheet->lastRow();
        int mycol = sheet->lastCol();
        dataMat = cvCreateMat(myrow, mycol, CV_32FC1); // stores the imported data
        for (int i = 0; i < myrow; ++i)
            for (int j = 0; j < mycol; ++j)
                cvSetReal2D(dataMat, i, j, sheet->readNum(i, j));
    }
}
if (book != nullptr)
    book->release();

const int kClasses = 4;   // 4 suits
const int kPerClass = 10; // 10 samples per suit

if (dataMat == nullptr || dataMat->rows != kClasses * kPerClass)
{
    MessageBox(L"Data import failed");
    if (dataMat != nullptr)
        cvReleaseMat(&dataMat);
}
else
{
    MessageBox(L"Data import complete");

    // Class labels; assumes the rows are grouped by suit in this order:
    // rows 0-9 -> 1 (spades), 10-19 -> 2 (hearts), 20-29 -> 3 (clubs), 30-39 -> 4 (diamonds)
    CvMat* labelMat = cvCreateMat(dataMat->rows, 1, CV_32FC1);
    for (int i = 0; i < kClasses; ++i)
        for (int j = 0; j < kPerClass; ++j)
            cvSetReal2D(labelMat, i * kPerClass + j, 0, i + 1);

    CvNormalBayesClassifier nbc;
    nbc.train(dataMat, labelMat);
    nbc.save("bayes.txt");
    MessageBox(L"Training complete");

    // Re-predict every training sample (accuracy measured on the training set itself)
    int nCount = 0;
    for (int i = 0; i < dataMat->rows; ++i)
    {
        CvMat rowHeader;
        cvGetRow(dataMat, &rowHeader, i); // header for row i, no data copy
        int ret = cvRound(nbc.predict(&rowHeader));
        if (ret == cvRound(cvGetReal2D(labelMat, i, 0)))
            ++nCount;
    }

    // Floating-point division; the integer expression 100 * nCount / 10 / 4 truncates
    float recognize = 100.0f * nCount / dataMat->rows;

    CString str;
    str.Format(L"Naive Bayes recognition rate: %.2f%%", recognize);
    MessageBox(str);

    cvReleaseMat(&labelMat);
    cvReleaseMat(&dataMat);
}
```
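According to its documentation, `CvNormalBayesClassifier` assumes that each class's feature vectors are normally distributed "though not necessarily independently", so it fits a full covariance matrix per class. With 14 features and only 10 samples per suit, each class's sample covariance has rank at most 9 and is therefore singular. Textbook naive Bayes treats the features as independent and needs only one mean and one variance per feature. The sketch below implements that diagonal variant as a hypothetical `GaussianNB` class (not an OpenCV API). It assumes labels 1..numClasses, like the 1..4 suit labels here. The small variance floor `varFloor` is an added assumption that keeps a constant feature from producing zero variance.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Gaussian naive Bayes: each feature is an independent normal distribution
// per class (diagonal covariance). Labels are assumed to be 1..numClasses.
class GaussianNB {
public:
    void train(const std::vector<std::vector<double>>& X, const std::vector<int>& y,
               int numClasses, double varFloor = 1e-6) {
        assert(!X.empty() && X.size() == y.size() && numClasses > 0);
        const std::size_t d = X[0].size();
        mean_.assign(numClasses, std::vector<double>(d, 0.0));
        var_.assign(numClasses, std::vector<double>(d, 0.0));
        logPrior_.assign(numClasses, 0.0);
        std::vector<int> count(numClasses, 0);

        // Per-class feature means.
        for (std::size_t i = 0; i < X.size(); ++i) {
            const int c = y[i] - 1;
            assert(c >= 0 && c < numClasses && X[i].size() == d);
            ++count[c];
            for (std::size_t j = 0; j < d; ++j) mean_[c][j] += X[i][j];
        }
        for (int c = 0; c < numClasses; ++c) {
            assert(count[c] > 0 && "every class needs at least one sample");
            for (double& m : mean_[c]) m /= count[c];
        }
        // Per-class feature variances (maximum-likelihood estimate plus a floor).
        for (std::size_t i = 0; i < X.size(); ++i) {
            const int c = y[i] - 1;
            for (std::size_t j = 0; j < d; ++j) {
                const double diff = X[i][j] - mean_[c][j];
                var_[c][j] += diff * diff;
            }
        }
        for (int c = 0; c < numClasses; ++c) {
            for (double& v : var_[c]) v = v / count[c] + varFloor;
            logPrior_[c] = std::log(static_cast<double>(count[c]) / static_cast<double>(X.size()));
        }
    }

    // Returns the label with the largest log posterior. The -0.5 * log(2 * pi)
    // term per feature is identical for every class, so it is omitted.
    int predict(const std::vector<double>& x) const {
        assert(!mean_.empty() && x.size() == mean_[0].size());
        int best = 0;
        double bestScore = -std::numeric_limits<double>::infinity();
        for (std::size_t c = 0; c < mean_.size(); ++c) {
            double score = logPrior_[c];
            for (std::size_t j = 0; j < x.size(); ++j) {
                const double diff = x[j] - mean_[c][j];
                score -= 0.5 * (std::log(var_[c][j]) + diff * diff / var_[c][j]);
            }
            if (score > bestScore) { bestScore = score; best = static_cast<int>(c); }
        }
        return best + 1;
    }

private:
    std::vector<std::vector<double>> mean_, var_;
    std::vector<double> logPrior_;
};
```

With the 40 x 14 feature rows loaded into a `std::vector<std::vector<double>>`, training would be `nb.train(X, y, 4)`. The recognition-rate loop would stay the same, with accuracy computed in floating point.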
The recognition code is as follows:
```cpp
CvNormalBayesClassifier nbc;
nbc.load("bayes.txt");

CFileDialog dlg(TRUE, NULL, NULL, 0, L"Image files (*.jpg)|*.jpg||");
if (dlg.DoModal() == IDOK)
{
    USES_CONVERSION;
    const char* loadPath = W2A(dlg.GetPathName());
    IplImage* testImage = cvLoadImage(loadPath);

    if (testImage == nullptr)
    {
        AfxMessageBox(L"Failed to load the image");
    }
    else
    {
        CvMat* mat = cvCreateMat(1, 14, CV_32FC1);
        cvZero(mat); // no stale values if no contour passes the area threshold
        CollectCharacter(testImage, mat, 0); // note: also releases testImage

        int ret = cvRound(nbc.predict(mat)); // predict() returns the label as a float
        CString str;
        switch (ret)
        {
        case 1:
            str = L"Spades";
            break;
        case 2:
            str = L"Hearts";
            break;
        case 3:
            str = L"Clubs";
            break;
        case 4:
            str = L"Diamonds";
            break;
        default:
            str = L"Unknown";
            break;
        }
        AfxMessageBox(str);
        cvReleaseMat(&mat);
    }
}
```
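`CvNormalBayesClassifier::predict` returns the class label as a float. One way to keep the label-to-suit mapping in a single place is a small lookup helper, shown here as a hypothetical `suitName` function. It rounds the label before indexing and reports anything outside 1-4 as unknown instead of leaving the string empty. The label scheme (1 = spades, 2 = hearts, 3 = clubs, 4 = diamonds) is the one used in the training and recognition code.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <string>

// Maps a predicted label to a suit name.
// Label scheme: 1 = spades, 2 = hearts, 3 = clubs, 4 = diamonds.
std::string suitName(float label) {
    static const std::array<const char*, 4> kNames = {"Spades", "Hearts", "Clubs", "Diamonds"};
    const long idx = std::lround(label) - 1;  // round first, since predict() returns a float
    if (idx < 0 || idx >= static_cast<long>(kNames.size()))
        return "Unknown";  // not one of the four trained labels
    return kNames[static_cast<std::size_t>(idx)];
}
```

In the MFC handler this would replace the `switch`, for example `CString str(suitName(nbc.predict(mat)).c_str());`.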