Correlation Filter in Visual Tracking
涉及两篇论文:Visual Object Tracking using Adaptive Correlation Filters 和Fast Visual Tracking via Dense Spatio-Temporal Context Learning
可参考这位博主笔记:http://www.cnblogs.com/hanhuili/p/4266990.html
第一篇我说下自己的理解:训练时把期望输出设为高斯形状,这样便于用峰值旁瓣比(PSR, Peak-to-Sidelobe Ratio)来衡量响应的尖锐程度,即跟踪的置信度。
训练得到模板后开始跟踪,由输出继续按照新的规则更新模板,进行跟踪。
第二篇主要用到了上下文的信息,通过背景信息来确定目标的位置。可参考这篇博文:http://blog.csdn.net/zouxy09/article/details/16889905,博主还将其用C++实现了,很有启发性。
STCTracker.h
// Fast object tracking algorithm // Author : zouxy // Date : 2013-11-21 // HomePage : http://blog.csdn.net/zouxy09 // Email : zouxy09@qq.com // Reference: Kaihua Zhang, et al. Fast Tracking via Spatio-Temporal Context Learning // HomePage : http://www4.comp.polyu.edu.hk/~cskhzhang/ // Email: zhkhua@gmail.com #pragma once #include <opencv2/opencv.hpp> using namespace cv; using namespace std; class STCTracker { public: STCTracker(); ~STCTracker(); void init(const Mat frame, const Rect box); void tracking(const Mat frame, Rect &trackBox); private: void createHammingWin(); void complexOperation(const Mat src1, const Mat src2, Mat &dst, int flag = 0); void getCxtPriorPosteriorModel(const Mat image); void learnSTCModel(const Mat image); private: double sigma; // scale parameter (variance) double alpha; // scale parameter double beta; // shape parameter double rho; // learning parameter Point center; // the object position Rect cxtRegion; // context region Mat cxtPriorPro; // prior probability Mat cxtPosteriorPro; // posterior probability Mat STModel; // conditional probability Mat STCModel; // spatio-temporal context model Mat hammingWin; // Hamming window };
STCTracker.cpp
// Fast object tracking algorithm // Author : zouxy // Date : 2013-11-21 // HomePage : http://blog.csdn.net/zouxy09 // Email : zouxy09@qq.com // Reference: Kaihua Zhang, et al. Fast Tracking via Spatio-Temporal Context Learning // HomePage : http://www4.comp.polyu.edu.hk/~cskhzhang/ // Email: zhkhua@gmail.com #include "STCTracker.h" STCTracker::STCTracker() { } STCTracker::~STCTracker() { } /************ Create a Hamming window ********************/ void STCTracker::createHammingWin() { for (int i = 0; i < hammingWin.rows; i++) { for (int j = 0; j < hammingWin.cols; j++) { hammingWin.at<double>(i, j) = (0.54 - 0.46 * cos( 2 * CV_PI * i / hammingWin.rows )) * (0.54 - 0.46 * cos( 2 * CV_PI * j / hammingWin.cols )); } } } /************ Define two complex-value operation *****************/ void STCTracker::complexOperation(const Mat src1, const Mat src2, Mat &dst, int flag) { CV_Assert(src1.size == src2.size); CV_Assert(src1.channels() == 2); Mat A_Real, A_Imag, B_Real, B_Imag, R_Real, R_Imag; vector<Mat> planes; split(src1, planes); planes[0].copyTo(A_Real); planes[1].copyTo(A_Imag); split(src2, planes); planes[0].copyTo(B_Real); planes[1].copyTo(B_Imag); dst.create(src1.rows, src1.cols, CV_64FC2); split(dst, planes); R_Real = planes[0]; R_Imag = planes[1]; for (int i = 0; i < A_Real.rows; i++) { for (int j = 0; j < A_Real.cols; j++) { double a = A_Real.at<double>(i, j); double b = A_Imag.at<double>(i, j); double c = B_Real.at<double>(i, j); double d = B_Imag.at<double>(i, j); if (flag) { // division: (a+bj) / (c+dj) R_Real.at<double>(i, j) = (a * c + b * d) / (c * c + d * d + 0.000001); R_Imag.at<double>(i, j) = (b * c - a * d) / (c * c + d * d + 0.000001); } else { // multiplication: (a+bj) * (c+dj) R_Real.at<double>(i, j) = a * c - b * d; R_Imag.at<double>(i, j) = b * c + a * d; } } } merge(planes, dst); } /************ Get context prior and posterior probability ***********/ void STCTracker::getCxtPriorPosteriorModel(const Mat image) { CV_Assert(image.size == 
cxtPriorPro.size); double sum_prior(0), sum_post(0); for (int i = 0; i < cxtRegion.height; i++) { for (int j = 0; j < cxtRegion.width; j++) { double x = j + cxtRegion.x; double y = i + cxtRegion.y; double dist = sqrt((center.x - x) * (center.x - x) + (center.y - y) * (center.y - y)); // equation (5) in the paper cxtPriorPro.at<double>(i, j) = exp(- dist * dist / (2 * sigma * sigma)); sum_prior += cxtPriorPro.at<double>(i, j); // equation (6) in the paper cxtPosteriorPro.at<double>(i, j) = exp(- pow(dist / sqrt(alpha), beta)); sum_post += cxtPosteriorPro.at<double>(i, j); } } cxtPriorPro.convertTo(cxtPriorPro, -1, 1.0/sum_prior); cxtPriorPro = cxtPriorPro.mul(image); cxtPosteriorPro.convertTo(cxtPosteriorPro, -1, 1.0/sum_post); } /************ Learn Spatio-Temporal Context Model ***********/ void STCTracker::learnSTCModel(const Mat image) { // step 1: Get context prior and posterior probability getCxtPriorPosteriorModel(image); // step 2-1: Execute 2D DFT for prior probability Mat priorFourier; Mat planes1[] = {cxtPriorPro, Mat::zeros(cxtPriorPro.size(), CV_64F)}; merge(planes1, 2, priorFourier); dft(priorFourier, priorFourier); // step 2-2: Execute 2D DFT for posterior probability Mat postFourier; Mat planes2[] = {cxtPosteriorPro, Mat::zeros(cxtPosteriorPro.size(), CV_64F)}; merge(planes2, 2, postFourier); dft(postFourier, postFourier); // step 3: Calculate the division Mat conditionalFourier; complexOperation(postFourier, priorFourier, conditionalFourier, 1); // step 4: Execute 2D inverse DFT for conditional probability and we obtain STModel dft(conditionalFourier, STModel, DFT_INVERSE | DFT_REAL_OUTPUT | DFT_SCALE); // step 5: Use the learned spatial context model to update spatio-temporal context model addWeighted(STCModel, 1.0 - rho, STModel, rho, 0.0, STCModel); } /************ Initialize the hyper parameters and models ***********/ void STCTracker::init(const Mat frame, const Rect box) { // initial some parameters alpha = 2.25; beta = 1; rho = 0.075; sigma = 
0.5 * (box.width + box.height); // the object position center.x = box.x + 0.5 * box.width; center.y = box.y + 0.5 * box.height; // the context region cxtRegion.width = 2 * box.width; cxtRegion.height = 2 * box.height; cxtRegion.x = center.x - cxtRegion.width * 0.5; cxtRegion.y = center.y - cxtRegion.height * 0.5; cxtRegion &= Rect(0, 0, frame.cols, frame.rows); // the prior, posterior and conditional probability and spatio-temporal context model cxtPriorPro = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1); cxtPosteriorPro = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1); STModel = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1); STCModel = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1); // create a Hamming window hammingWin = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1); createHammingWin(); Mat gray; cvtColor(frame, gray, CV_RGB2GRAY); // normalized by subtracting the average intensity of that region Scalar average = mean(gray(cxtRegion)); Mat context; gray(cxtRegion).convertTo(context, CV_64FC1, 1.0, - average[0]); // multiplies a Hamming window to reduce the frequency effect of image boundary context = context.mul(hammingWin); // learn Spatio-Temporal context model from first frame learnSTCModel(context); } /******** STCTracker: calculate the confidence map and find the max position *******/ void STCTracker::tracking(const Mat frame, Rect &trackBox) { Mat gray; cvtColor(frame, gray, CV_RGB2GRAY); // normalized by subtracting the average intensity of that region Scalar average = mean(gray(cxtRegion)); Mat context; gray(cxtRegion).convertTo(context, CV_64FC1, 1.0, - average[0]); // multiplies a Hamming window to reduce the frequency effect of image boundary context = context.mul(hammingWin); // step 1: Get context prior probability getCxtPriorPosteriorModel(context); // step 2-1: Execute 2D DFT for prior probability Mat priorFourier; Mat planes1[] = {cxtPriorPro, Mat::zeros(cxtPriorPro.size(), CV_64F)}; merge(planes1, 
2, priorFourier); dft(priorFourier, priorFourier); // step 2-2: Execute 2D DFT for conditional probability Mat STCModelFourier; Mat planes2[] = {STCModel, Mat::zeros(STCModel.size(), CV_64F)}; merge(planes2, 2, STCModelFourier); dft(STCModelFourier, STCModelFourier); // step 3: Calculate the multiplication Mat postFourier; complexOperation(STCModelFourier, priorFourier, postFourier, 0); // step 4: Execute 2D inverse DFT for posterior probability namely confidence map Mat confidenceMap; dft(postFourier, confidenceMap, DFT_INVERSE | DFT_REAL_OUTPUT| DFT_SCALE); // step 5: Find the max position Point point; minMaxLoc(confidenceMap, 0, 0, 0, &point); // step 6-1: update center, trackBox and context region center.x = cxtRegion.x + point.x; center.y = cxtRegion.y + point.y; trackBox.x = center.x - 0.5 * trackBox.width; trackBox.y = center.y - 0.5 * trackBox.height; trackBox &= Rect(0, 0, frame.cols, frame.rows); cxtRegion.x = center.x - cxtRegion.width * 0.5; cxtRegion.y = center.y - cxtRegion.height * 0.5; cxtRegion &= Rect(0, 0, frame.cols, frame.rows); // step 7: learn Spatio-Temporal context model from this frame for tracking next frame average = mean(gray(cxtRegion)); gray(cxtRegion).convertTo(context, CV_64FC1, 1.0, - average[0]); context = context.mul(hammingWin); learnSTCModel(context); }
runTracker.cpp
// Fast object tracking algorithm // Author : zouxy // Date : 2013-11-21 // HomePage : http://blog.csdn.net/zouxy09 // Email : zouxy09@qq.com // Reference: Kaihua Zhang, et al. Fast Tracking via Spatio-Temporal Context Learning // HomePage : http://www4.comp.polyu.edu.hk/~cskhzhang/ // Email: zhkhua@gmail.com #include "STCTracker.h" // Global variables Rect box; bool drawing_box = false; bool gotBB = false; // bounding box mouse callback void mouseHandler(int event, int x, int y, int flags, void *param){ switch( event ){ case CV_EVENT_MOUSEMOVE: if (drawing_box){ box.width = x-box.x; box.height = y-box.y; } break; case CV_EVENT_LBUTTONDOWN: drawing_box = true; box = Rect( x, y, 0, 0 ); break; case CV_EVENT_LBUTTONUP: drawing_box = false; if( box.width < 0 ){ box.x += box.width; box.width *= -1; } if( box.height < 0 ){ box.y += box.height; box.height *= -1; } gotBB = true; break; } } int main(int argc, char * argv[]) { VideoCapture capture; capture.open("handwave.wmv"); bool fromfile = true; if (!capture.isOpened()) { cout << "capture device failed to open!" 
<< endl; return -1; } //Register mouse callback to draw the bounding box cvNamedWindow("Tracker", CV_WINDOW_AUTOSIZE); cvSetMouseCallback("Tracker", mouseHandler, NULL ); Mat frame; capture >> frame; while(!gotBB) { if (!fromfile) capture >> frame; imshow("Tracker", frame); if (cvWaitKey(20) == 27) return 1; } //Remove callback cvSetMouseCallback("Tracker", NULL, NULL ); STCTracker stcTracker; stcTracker.init(frame, box); int frameCount = 0; while (1) { capture >> frame; if (frame.empty()) return -1; double t = (double)cvGetTickCount(); frameCount++; // tracking stcTracker.tracking(frame, box); // show the result stringstream buf; buf << frameCount; string num = buf.str(); putText(frame, num, Point(20, 30), FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 0, 255), 3); rectangle(frame, box, Scalar(0, 0, 255), 3); imshow("Tracker", frame); t = (double)cvGetTickCount() - t; cout << "cost time: " << t / ((double)cvGetTickFrequency()*1000.) << endl; if ( cvWaitKey(1) == 27 ) break; } return 0; }
这篇论文的源码作者已经给出,非常简洁,有空再研究。
文中还将生成模型和判别模型进行了对比。生成模型一般就是学习一个外表模型来表示目标,然后寻找最匹配的图像区域。
判别模型把跟踪作为一个分类问题,评估目标和背景的决策边界。
为了便于理解,我把流程图画了下来,visio用的还不熟,不知道拐弯的箭头咋画,所以那个循环没画出来