boost强分类器的实现

boost.cpp文件下:

bool CvCascadeBoost::train( const CvFeatureEvaluator* _featureEvaluator,
                           int _numSamples,
                           int _precalcValBufSize, int _precalcIdxBufSize,
                           const CvCascadeBoostParams& _params )

该函数是boost方法的入口函数。

    // Excerpt of the function body: apply the training parameters.
    set_params( _params );
    // For LOGIT or GENTLE boosting, the class labels of the samples are copied from
    // _featureEvaluator->cls into data->responses, because these two boosting variants
    // work with labels remapped from 0/1 to -1/+1.
    if ( (_params.boost_type == LOGIT) || (_params.boost_type == GENTLE) )
        data->do_responses_copy();
    // Initialise the weight of every sample to 1/n.
    update_weights( 0 );

    cout << "+----+---------+---------+" << endl;
    cout << "|  N |    HR   |    FA   |" << endl;
    cout << "+----+---------+---------+" << endl;

    do
    {
        CvCascadeBoostTree* tree = new CvCascadeBoostTree;
        // Train one weak classifier; each weak classifier is a CART tree.
        // On failure the tree is released and training stops.
        if( !tree->train( data, subsample_mask, this ) )
        {
            delete tree;
            break;
        }
        // Append the trained tree pointer to the sequence of weak classifiers.
        cvSeqPush( weak, &tree );
        // Update the sample weights according to the boosting formula.
        update_weights( tree );
        // Following the user-supplied parameter, drop a fraction (0.05) of the samples
        // with the smallest weights.
        trim_weights();
        // subsample_mask stores a 0/1 flag per sample telling whether it takes part in
        // training; if no usable sample is left, stop training.
        if( cvCountNonZero(subsample_mask) == 0 )
            break;
    }
    // Keep adding weak classifiers until isErrDesired() reports that the strong classifier
    // meets the configured false-alarm requirement, or the number of weak classifiers
    // reaches params.weak_count.
    while( !isErrDesired() && (weak->total < params.weak_count) );

后续的执行流程可以参见http://blog.csdn.net/beerbuddys/article/details/40712957

// Recursively grows the subtree rooted at `node`.
//
// The node value is computed first. The node is then left unsplit (its data is
// released via data->free_node_data) when any of the following holds:
//   - it has no more than params.min_sample_count samples, or its depth has
//     reached params.max_depth;
//   - classification: exactly one class has a non-zero count;
//   - regression: sqrt(node_risk)/sample_count is below params.regression_accuracy;
//   - find_best_split() returns no split.
// Otherwise the primary split is stored in node->split, optional surrogate splits
// are chained after it in order of decreasing quality, the node data is divided,
// and both children are processed the same way.
void CvDTree::try_split_node( CvDTreeNode* node )
{
    // Read the sample count before calc_node_value() runs.
    const int sample_total = node->sample_count;

    calc_node_value( node );

    // Size / depth limits.
    bool splittable = node->sample_count > data->params.min_sample_count &&
                      node->depth < data->params.max_depth;

    if( splittable )
    {
        if( data->is_classifier )
        {
            // A "pure" node (only one class present) is not split further.
            // The per-class counts are assumed to be filled by calc_node_value().
            const int* class_counts = data->counts->data.i;
            const int class_total = data->get_num_classes();
            int non_empty = 0;
            for( int k = 0; k < class_total; k++ )
            {
                if( class_counts[k] != 0 )
                    non_empty++;
            }
            splittable = non_empty != 1;
        }
        else
        {
            // Regression: stop once the node is already accurate enough.
            splittable = !( sqrt(node->node_risk)/sample_total < data->params.regression_accuracy );
        }
    }

    CvDTreeSplit* primary = 0;
    if( splittable )
    {
        primary = find_best_split( node );
        // TODO: check the split quality ...
        node->split = primary;
    }

    // Either splitting was ruled out above or no split was found.
    if( !primary )
    {
        data->free_node_data( node );
        return;
    }

    const double quality_scale = calc_node_dir( node );

    if( data->params.use_surrogates )
    {
        // Look for a surrogate split on every other variable and keep the list
        // behind the primary split sorted by decreasing quality.
        for( int var = 0; var < data->var_count; var++ )
        {
            const int cat_index = data->get_var_type( var );

            if( var == primary->var_idx )
                continue;

            CvDTreeSplit* surrogate;
            if( cat_index >= 0 )
                surrogate = find_surrogate_split_cat( node, var );
            else
                surrogate = find_surrogate_split_ord( node, var );

            if( !surrogate )
                continue;

            surrogate->quality = (float)(surrogate->quality*quality_scale);

            // Walk past every entry that is strictly better, then link in.
            CvDTreeSplit* cursor = node->split;
            while( cursor->next && cursor->next->quality > surrogate->quality )
                cursor = cursor->next;

            surrogate->next = cursor->next;
            cursor->next = surrogate;
        }
    }

    split_node_data( node );

    // Continue with the left and right children.
    try_split_node( node->left );
    try_split_node( node->right );
}

其中calc_node_value计算结点的value,对应代码是

void
CvBoostTree::calc_node_value( CvDTreeNode* node )

然后执行到tree.cpp中的:

// Finds the best split for `node` over all variables.
//
// A DTreeBestSplitFinder is run over the variable range [0, data->var_count).
// Returns a newly allocated copy of the finder's best split when its quality is
// greater than zero, and 0 otherwise.
CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
{
    DTreeBestSplitFinder finder( this, node );

    // According to the original note, this runs on multiple cores when TBB is enabled.
    cv::parallel_reduce( cv::BlockedRange(0, data->var_count), finder );

    // No split with positive quality was found.
    if( !( finder.bestSplit->quality > 0 ) )
        return 0;

    // Allocate a split from the tree's storage and copy the winner into it.
    CvDTreeSplit* winner = data->new_split_cat( 0, -1.0f );
    memcpy( winner, finder.bestSplit, finder.splitSize );
    return winner;
}

进一步看operator()函数

// Evaluates every variable in [range.begin(), range.end()) and keeps the
// highest-quality split seen so far in bestSplit.
// In the regression case with an ordered variable, tree->find_split_ord_reg
// looks for the best threshold on variable vi.
void DTreeBestSplitFinder::operator()(const BlockedRange& range)
{
    const int var_begin = range.begin();
    const int var_end = range.end();
    const int sample_total = node->sample_count;
    CvDTreeTrainData* train_data = tree->get_data();

    // Scratch memory shared by all find_split_* calls below.
    AutoBuffer<uchar> scratch(2*sample_total*(sizeof(int) + sizeof(float)));
    uchar* const scratch_ptr = (uchar*)scratch;

    for( int var = var_begin; var < var_end; var++ )
    {
        const int cat_index = train_data->get_var_type( var );

        // Skip variables with at most one valid value in this node.
        if( node->get_num_valid( var ) <= 1 )
            continue;

        const bool categorical = cat_index >= 0;
        CvDTreeSplit* candidate;

        if( train_data->is_classifier && categorical )
            candidate = tree->find_split_cat_class( node, var, bestSplit->quality, split, scratch_ptr );
        else if( train_data->is_classifier )
            candidate = tree->find_split_ord_class( node, var, bestSplit->quality, split, scratch_ptr );
        else if( categorical )
            candidate = tree->find_split_cat_reg( node, var, bestSplit->quality, split, scratch_ptr );
        else
            // Best split for ordered variable `var`, i.e. the best threshold.
            candidate = tree->find_split_ord_reg( node, var, bestSplit->quality, split, scratch_ptr );

        // Keep the split with the highest quality.
        if( candidate && bestSplit->quality < split->quality )
            memcpy( bestSplit.get(), split.get(), splitSize );
    }
}

find_split_ord_reg要做的事情就是寻找最优分割,找到一个阈值将数据分为两部分,并保证两边的总体误差最小。

策略是:将特征值排序,然后依次把values[i]和values[i+1]的中值作为候选阈值进行测试,取quality最高的一个作为最优阈值。

// Searches for the best threshold split of `node` on the ordered variable `vi`
// (regression criterion, weighted by the ensemble's sample weights).
//
// node          node whose samples are being split
// vi            index of the variable to split on
// init_quality  quality a candidate must exceed to be accepted
// _split        optional output storage; when null a split is allocated via
//               data->new_split_ord
// _ext_buf      optional scratch buffer; when null an internal buffer of
//               2*n*(sizeof(int)+sizeof(float)) bytes is allocated
//
// Returns the filled-in split, or 0 when no candidate position has a quality
// greater than init_quality.
CvDTreeSplit*
CvBoostTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
    // Neighbouring values closer than this are treated as equal (no threshold between them).
    const float epsilon = FLT_EPSILON*2;
    const double* weights = ensemble->get_subtree_weights()->data.db;
    int n = node->sample_count;
    int n1 = node->get_num_valid(vi);

    // Use the caller's scratch buffer when given, otherwise allocate one locally.
    cv::AutoBuffer<uchar> inn_buf;
    if( !_ext_buf )
        inn_buf.allocate(2*n*(sizeof(int)+sizeof(float)));
    uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;

    // Carve the scratch buffer into n floats (values) followed by two runs of n ints.
    float* values_buf = (float*)ext_buf;
    int* indices_buf = (int*)(values_buf + n);
    int* sample_indices_buf = indices_buf + n;
    const float* values = 0;
    const int* indices = 0;
    // Fetch the values of feature vi (the Haar feature values, per the original note)
    // for the samples of this node; `values` is the array of feature values, stated
    // by the original note to be sorted in ascending order.
    data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices, sample_indices_buf );
    // NOTE(review): responses_buf starts at the same address as sample_indices_buf
    // (both are indices_buf + n) — the two regions overlap; confirm this is intended.
    float* responses_buf = (float*)(indices_buf + n);
    const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );

    int i, best_i = -1;
    // L / R: accumulated weight on the left / right side of the candidate threshold.
    // R starts from weights[n] — presumably the total weight of the node's samples; TODO confirm.
    double L = 0, R = weights[n];
    // lsum / rsum: weighted response sums on each side; rsum starts at node->value*R.
    double best_val = init_quality, lsum = 0, rsum = node->value*R;

    // compensate for missing values
    // (entries at positions [n1, n) are removed from the right-hand totals)
    for( i = n1; i < n; i++ )
    {
        int idx = indices[i];
        double w = weights[idx];
        rsum -= responses[idx]*w;
        R -= w;
    }

    // find the optimal split
    // (move one sample at a time from the right side to the left side)
    for( i = 0; i < n1 - 1; i++ )
    {
        int idx = indices[i];
        double w = weights[idx];
        double t = responses[idx]*w;
        L += w; R -= w;
        lsum += t; rsum -= t;

        // Only positions where the value actually changes can host a threshold.
        if( values[i] + epsilon < values[i+1] )
        {
            // Equivalent to lsum^2/L + rsum^2/R; a larger value is a better split.
            double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    CvDTreeSplit* split = 0;
    if( best_i >= 0 )
    {
        split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
        split->var_idx = vi;
        // The threshold is the midpoint between the two neighbouring values.
        split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
        split->ord.split_point = best_i;
        split->inversed = 0;
        split->quality = (float)best_val;
    }
    return split;
}

这类似于要求左右两部分的熵越低越好。特征值的计算见函数:

void CvCascadeBoostTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ordValuesBuf, int* sortedIndicesBuf,
        const float** ordValues, const int** sortedIndices, int* sampleIndicesBuf )
posted @ 2015-10-09 09:25  侯凯  阅读(863)  评论(0编辑  收藏  举报