Correlation Filter in Visual Tracking

     涉及两篇论文:Visual Object Tracking using Adaptive Correlation Filters 和Fast Visual Tracking via Dense Spatio-Temporal Context Learning

     可参考这位博主笔记:http://www.cnblogs.com/hanhuili/p/4266990.html

     第一篇我说下自己的理解:训练时把期望输出都设为以目标中心为峰值的高斯形状,因为这种尖锐的单峰响应能得到较高的 PSR(峰值旁瓣比,peak-to-sidelobe ratio)。

                              训练得到模板后开始跟踪,并根据每一帧的输出按照在线更新规则更新模板,继续跟踪。

      第二篇主要用到了上下文的信息,通过背景信息来确定目标的位置。可参考这篇博文:http://blog.csdn.net/zouxy09/article/details/16889905,博主还将其用C++实现了,很有启发性。

        STCTracker.h

// Fast object tracking algorithm  
// Author : zouxy  
// Date   : 2013-11-21  
// HomePage : http://blog.csdn.net/zouxy09  
// Email  : zouxy09@qq.com  
// Reference: Kaihua Zhang, et al. Fast Tracking via Spatio-Temporal Context Learning  
// HomePage : http://www4.comp.polyu.edu.hk/~cskhzhang/  
// Email: zhkhua@gmail.com   
#pragma once  
  
#include <opencv2/opencv.hpp>  
  
using namespace cv;  
using namespace std;  
  
/**
 * Single-object tracker based on spatio-temporal context (STC) learning.
 * Usage: call init() once with the first frame and the target box, then
 * call tracking() on every following frame.
 */
class STCTracker  
{  
public:  
    STCTracker();  
    ~STCTracker();  
    /// Learn the initial context model from `frame` around the target `box`.
    void init(const Mat frame, const Rect box);   
    /// Locate the target in `frame`. The box size is read from trackBox and
    /// the new (frame-clipped) position is written back into it.
    void tracking(const Mat frame, Rect &trackBox);  
  
private:  
    /// Fill hammingWin (already allocated) with a 2-D Hamming window.
    void createHammingWin();  
    /// Element-wise complex product (flag == 0) or quotient (flag != 0) of
    /// two 2-channel (real, imaginary) matrices.
    void complexOperation(const Mat src1, const Mat src2, Mat &dst, int flag = 0);  
    /// Compute the normalised prior map (weighted by `image`) and the
    /// normalised posterior map around `center`.
    void getCxtPriorPosteriorModel(const Mat image);  
    /// Learn the spatial context model from `image` and blend it into STCModel.
    void learnSTCModel(const Mat image);  
  
private:  
    double sigma;           // std-dev (pixels) of the Gaussian prior weight exp(-d^2 / (2 sigma^2))  
    double alpha;           // scale parameter of the posterior (confidence) map  
    double beta;            // shape parameter of the posterior (confidence) map  
    double rho;             // learning rate of the STCModel running average  
    Point center;           // current object center (frame coordinates)  
    Rect cxtRegion;         // context region (frame coordinates)  
      
    Mat cxtPriorPro;        // context prior probability map  
    Mat cxtPosteriorPro;    // posterior probability (confidence) map  
    Mat STModel;            // spatial context model learned from the latest patch  
    Mat STCModel;           // spatio-temporal context model (running average of STModel)  
    Mat hammingWin;         // Hamming window applied to the context patch  
};  
View Code

       STCTracker.cpp

// Fast object tracking algorithm  
// Author : zouxy  
// Date   : 2013-11-21  
// HomePage : http://blog.csdn.net/zouxy09  
// Email  : zouxy09@qq.com  
// Reference: Kaihua Zhang, et al. Fast Tracking via Spatio-Temporal Context Learning  
// HomePage : http://www4.comp.polyu.edu.hk/~cskhzhang/  
// Email: zhkhua@gmail.com   
  
#include "STCTracker.h"  
  
// Give every scalar member a defined value up front. The hyper-parameter
// values match the ones init() sets; init() must still be called before
// tracking(), but a premature call no longer reads indeterminate doubles.
STCTracker::STCTracker()
    : sigma(0.0), alpha(2.25), beta(1.0), rho(0.075), center(0, 0), cxtRegion()
{
}

// All members are RAII types (cv::Mat etc.); nothing to release by hand.
STCTracker::~STCTracker()
{
}
  
/************ Create a Hamming window ********************/  
void STCTracker::createHammingWin()  
{  
    for (int i = 0; i < hammingWin.rows; i++)  
    {  
        for (int j = 0; j < hammingWin.cols; j++)  
        {  
            hammingWin.at<double>(i, j) = (0.54 - 0.46 * cos( 2 * CV_PI * i / hammingWin.rows ))   
                                        * (0.54 - 0.46 * cos( 2 * CV_PI * j / hammingWin.cols ));  
        }  
    }  
}  
  
/************ Define two complex-value operation *****************/  
void STCTracker::complexOperation(const Mat src1, const Mat src2, Mat &dst, int flag)  
{  
    CV_Assert(src1.size == src2.size);  
    CV_Assert(src1.channels() == 2);  
  
    Mat A_Real, A_Imag, B_Real, B_Imag, R_Real, R_Imag;  
    vector<Mat> planes;  
    split(src1, planes);  
    planes[0].copyTo(A_Real);  
    planes[1].copyTo(A_Imag);  
      
    split(src2, planes);  
    planes[0].copyTo(B_Real);  
    planes[1].copyTo(B_Imag);  
      
    dst.create(src1.rows, src1.cols, CV_64FC2);  
    split(dst, planes);  
    R_Real = planes[0];  
    R_Imag = planes[1];  
      
    for (int i = 0; i < A_Real.rows; i++)  
    {  
        for (int j = 0; j < A_Real.cols; j++)  
        {  
            double a = A_Real.at<double>(i, j);  
            double b = A_Imag.at<double>(i, j);  
            double c = B_Real.at<double>(i, j);  
            double d = B_Imag.at<double>(i, j);  
  
            if (flag)  
            {  
                // division: (a+bj) / (c+dj)  
                R_Real.at<double>(i, j) = (a * c + b * d) / (c * c + d * d + 0.000001);  
                R_Imag.at<double>(i, j) = (b * c - a * d) / (c * c + d * d + 0.000001);  
            }  
            else  
            {  
                // multiplication: (a+bj) * (c+dj)  
                R_Real.at<double>(i, j) = a * c - b * d;  
                R_Imag.at<double>(i, j) = b * c + a * d;  
            }  
        }  
    }  
    merge(planes, dst);  
}  
  
/************ Get context prior and posterior probability ***********/  
void STCTracker::getCxtPriorPosteriorModel(const Mat image)  
{  
    CV_Assert(image.size == cxtPriorPro.size);  
  
    double sum_prior(0), sum_post(0);  
    for (int i = 0; i < cxtRegion.height; i++)  
    {  
        for (int j = 0; j < cxtRegion.width; j++)  
        {  
            double x = j + cxtRegion.x;  
            double y = i + cxtRegion.y;  
            double dist = sqrt((center.x - x) * (center.x - x) + (center.y - y) * (center.y - y));  
  
            // equation (5) in the paper  
            cxtPriorPro.at<double>(i, j) = exp(- dist * dist / (2 * sigma * sigma));  
            sum_prior += cxtPriorPro.at<double>(i, j);  
  
            // equation (6) in the paper  
            cxtPosteriorPro.at<double>(i, j) = exp(- pow(dist / sqrt(alpha), beta));  
            sum_post += cxtPosteriorPro.at<double>(i, j);  
        }  
    }  
    cxtPriorPro.convertTo(cxtPriorPro, -1, 1.0/sum_prior);  
    cxtPriorPro = cxtPriorPro.mul(image);  
    cxtPosteriorPro.convertTo(cxtPosteriorPro, -1, 1.0/sum_post);  
}  
  
/************ Learn Spatio-Temporal Context Model ***********/  
void STCTracker::learnSTCModel(const Mat image)  
{  
    // step 1: Get context prior and posterior probability  
    getCxtPriorPosteriorModel(image);  
      
    // step 2-1: Execute 2D DFT for prior probability  
    Mat priorFourier;  
    Mat planes1[] = {cxtPriorPro, Mat::zeros(cxtPriorPro.size(), CV_64F)};  
    merge(planes1, 2, priorFourier);  
    dft(priorFourier, priorFourier);  
  
    // step 2-2: Execute 2D DFT for posterior probability  
    Mat postFourier;  
    Mat planes2[] = {cxtPosteriorPro, Mat::zeros(cxtPosteriorPro.size(), CV_64F)};  
    merge(planes2, 2, postFourier);  
    dft(postFourier, postFourier);  
  
    // step 3: Calculate the division  
    Mat conditionalFourier;  
    complexOperation(postFourier, priorFourier, conditionalFourier, 1);  
  
    // step 4: Execute 2D inverse DFT for conditional probability and we obtain STModel  
    dft(conditionalFourier, STModel, DFT_INVERSE | DFT_REAL_OUTPUT | DFT_SCALE);  
  
    // step 5: Use the learned spatial context model to update spatio-temporal context model  
    addWeighted(STCModel, 1.0 - rho, STModel, rho, 0.0, STCModel);  
}  
  
/************ Initialize the hyper parameters and models ***********/  
void STCTracker::init(const Mat frame, const Rect box)  
{  
    // initial some parameters  
    alpha = 2.25;  
    beta = 1;  
    rho = 0.075;  
    sigma = 0.5 * (box.width + box.height);  
  
    // the object position  
    center.x = box.x + 0.5 * box.width;  
    center.y = box.y + 0.5 * box.height;  
  
    // the context region  
    cxtRegion.width = 2 * box.width;  
    cxtRegion.height = 2 * box.height;  
    cxtRegion.x = center.x - cxtRegion.width * 0.5;  
    cxtRegion.y = center.y - cxtRegion.height * 0.5;  
    cxtRegion &= Rect(0, 0, frame.cols, frame.rows);  
  
    // the prior, posterior and conditional probability and spatio-temporal context model  
    cxtPriorPro = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1);  
    cxtPosteriorPro = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1);  
    STModel = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1);  
    STCModel = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1);  
  
    // create a Hamming window  
    hammingWin = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1);  
    createHammingWin();  
  
    Mat gray;  
    cvtColor(frame, gray, CV_RGB2GRAY);  
  
    // normalized by subtracting the average intensity of that region  
    Scalar average = mean(gray(cxtRegion));  
    Mat context;  
    gray(cxtRegion).convertTo(context, CV_64FC1, 1.0, - average[0]);  
  
    // multiplies a Hamming window to reduce the frequency effect of image boundary  
    context = context.mul(hammingWin);  
  
    // learn Spatio-Temporal context model from first frame  
    learnSTCModel(context);  
}  
  
/******** STCTracker: calculate the confidence map and find the max position *******/  
void STCTracker::tracking(const Mat frame, Rect &trackBox)  
{  
    Mat gray;  
    cvtColor(frame, gray, CV_RGB2GRAY);  
  
    // normalized by subtracting the average intensity of that region  
    Scalar average = mean(gray(cxtRegion));  
    Mat context;  
    gray(cxtRegion).convertTo(context, CV_64FC1, 1.0, - average[0]);  
  
    // multiplies a Hamming window to reduce the frequency effect of image boundary  
    context = context.mul(hammingWin);  
  
    // step 1: Get context prior probability  
    getCxtPriorPosteriorModel(context);  
  
    // step 2-1: Execute 2D DFT for prior probability  
    Mat priorFourier;  
    Mat planes1[] = {cxtPriorPro, Mat::zeros(cxtPriorPro.size(), CV_64F)};  
    merge(planes1, 2, priorFourier);  
    dft(priorFourier, priorFourier);  
  
    // step 2-2: Execute 2D DFT for conditional probability  
    Mat STCModelFourier;  
    Mat planes2[] = {STCModel, Mat::zeros(STCModel.size(), CV_64F)};  
    merge(planes2, 2, STCModelFourier);  
    dft(STCModelFourier, STCModelFourier);  
  
    // step 3: Calculate the multiplication  
    Mat postFourier;  
    complexOperation(STCModelFourier, priorFourier, postFourier, 0);  
  
    // step 4: Execute 2D inverse DFT for posterior probability namely confidence map  
    Mat confidenceMap;  
    dft(postFourier, confidenceMap, DFT_INVERSE | DFT_REAL_OUTPUT| DFT_SCALE);  
  
    // step 5: Find the max position  
    Point point;  
    minMaxLoc(confidenceMap, 0, 0, 0, &point);  
  
    // step 6-1: update center, trackBox and context region  
    center.x = cxtRegion.x + point.x;  
    center.y = cxtRegion.y + point.y;  
    trackBox.x = center.x - 0.5 * trackBox.width;  
    trackBox.y = center.y - 0.5 * trackBox.height;  
    trackBox &= Rect(0, 0, frame.cols, frame.rows);  
  
    cxtRegion.x = center.x - cxtRegion.width * 0.5;  
    cxtRegion.y = center.y - cxtRegion.height * 0.5;  
    cxtRegion &= Rect(0, 0, frame.cols, frame.rows);  
  
    // step 7: learn Spatio-Temporal context model from this frame for tracking next frame  
    average = mean(gray(cxtRegion));  
    gray(cxtRegion).convertTo(context, CV_64FC1, 1.0, - average[0]);  
    context = context.mul(hammingWin);  
    learnSTCModel(context);  
}  
View Code

     runTracker.cpp

    

// Fast object tracking algorithm  
// Author : zouxy  
// Date   : 2013-11-21  
// HomePage : http://blog.csdn.net/zouxy09  
// Email  : zouxy09@qq.com  
// Reference: Kaihua Zhang, et al. Fast Tracking via Spatio-Temporal Context Learning  
// HomePage : http://www4.comp.polyu.edu.hk/~cskhzhang/  
// Email: zhkhua@gmail.com   
  
#include "STCTracker.h"  
  
// Global state shared by mouseHandler() and main()
Rect box;                  // target box: dragged by the user, then tracked
bool drawing_box(false);   // true while the left mouse button is held down
bool gotBB(false);         // set by mouseHandler() once a box has been drawn
  
// Bounding-box mouse callback used while the user selects the target:
//   left button down : start a new box anchored at (x, y)
//   mouse move       : while dragging, stretch the box to (x, y)
//   left button up   : finish the box, flip negative width/height so that
//                      (box.x, box.y) is the top-left corner, and set gotBB
// `flags` and `param` are unused.
void mouseHandler(int event, int x, int y, int flags, void *param){
    if (event == CV_EVENT_MOUSEMOVE)
    {
        if (drawing_box)
        {
            box.width = x - box.x;
            box.height = y - box.y;
        }
    }
    else if (event == CV_EVENT_LBUTTONDOWN)
    {
        drawing_box = true;
        box = Rect(x, y, 0, 0);
    }
    else if (event == CV_EVENT_LBUTTONUP)
    {
        drawing_box = false;
        // A box dragged up/left has negative extents: make them positive
        // and move the origin to the top-left corner.
        if (box.width < 0)
        {
            box.width = -box.width;
            box.x -= box.width;
        }
        if (box.height < 0)
        {
            box.height = -box.height;
            box.y -= box.height;
        }
        gotBB = true;
    }
}
  
int main(int argc, char * argv[])  
{  
    VideoCapture capture;  
    capture.open("handwave.wmv");  
    bool fromfile = true;  
  
    if (!capture.isOpened())  
    {  
        cout << "capture device failed to open!" << endl;  
        return -1;  
    }  
    //Register mouse callback to draw the bounding box  
    cvNamedWindow("Tracker", CV_WINDOW_AUTOSIZE);  
    cvSetMouseCallback("Tracker", mouseHandler, NULL );   
  
    Mat frame;  
    capture >> frame;  
    while(!gotBB)  
    {  
        if (!fromfile)  
            capture >> frame;  
  
        imshow("Tracker", frame);  
        if (cvWaitKey(20) == 27)  
            return 1;  
    }  
    //Remove callback  
    cvSetMouseCallback("Tracker", NULL, NULL );   
      
    STCTracker stcTracker;  
    stcTracker.init(frame, box);  
  
    int frameCount = 0;  
    while (1)  
    {  
        capture >> frame;  
        if (frame.empty())  
            return -1;  
        double t = (double)cvGetTickCount();  
        frameCount++;  
  
        // tracking  
        stcTracker.tracking(frame, box);      
  
        // show the result  
        stringstream buf;  
        buf << frameCount;  
        string num = buf.str();  
        putText(frame, num, Point(20, 30), FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 0, 255), 3);  
        rectangle(frame, box, Scalar(0, 0, 255), 3);  
        imshow("Tracker", frame);  
  
        t = (double)cvGetTickCount() - t;  
        cout << "cost time: " << t / ((double)cvGetTickFrequency()*1000.) << endl;  
  
        if ( cvWaitKey(1) == 27 )  
            break;  
    }  
  
    return 0;  
}  
View Code

              这篇论文的源码作者已经公开,非常简洁,有空再研究。

               文中还将生成模型和判别模型进行了对比。生成模型一般是学习一个外观模型来表示目标,然后寻找与之最匹配的图像区域。

               判别模型把跟踪作为一个分类问题,评估目标和背景的决策边界。

 

  为了便于理解,我把流程图画了下来。Visio 用得还不熟,不知道拐弯的箭头怎么画,所以流程中的那个循环没画出来。

  

   

  

 


  

  

  

posted @ 2015-06-02 10:17  牧马人夏峥  阅读(1514)  评论(0编辑  收藏  举报