opencv haartraining 分析三:icvCreateCARTStageClassifier

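本文介绍了 icvCreateCARTStageClassifier 函数:它以训练数据、样本索引、Haar特征等为参数,通过boosting逐个训练CART弱分类器并组合成一个强分类器(stage),直到虚警率不高于maxfalsealarm(或总分裂数达到maxsplits)为止;阈值按最小命中率minhitrate选取。函数返回该stage分类器,若未达到虚警率要求则返回NULL。文章详细阐述了构建过程中的关键步骤和技术细节。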

CvIntHaarClassifier* icvCreateCARTStageClassifier( CvHaarTrainingData* data,
                                                   CvMat* sampleIdx,
                                                   CvIntHaarFeatures* haarFeatures,
                                                   float minhitrate,
                                                   float maxfalsealarm,
                                                   int   symmetric,
                                                   float weightfraction,
                                                   int numsplits,
                                                   CvBoostType boosttype,
                                                   CvStumpError stumperror,
                                                   int maxsplits )

{

#ifdef CV_COL_ARRANGEMENT
    int flags = CV_COL_SAMPLE;
#else
    int flags = CV_ROW_SAMPLE;
#endif

    CvStageHaarClassifier* stage = NULL;
    CvBoostTrainer* trainer;
    CvCARTClassifier* cart = NULL;
    CvCARTTrainParams trainParams;
    CvMTStumpTrainParams stumpTrainParams;
    //CvMat* trainData = NULL;
    //CvMat* sortedIdx = NULL;
    CvMat eval;
    int n = 0;
    int m = 0;
    int numpos = 0;
    int numneg = 0;
    int numfalse = 0;
    float sum_stage = 0.0F;
    float threshold = 0.0F;
    float falsealarm = 0.0F;
   
    //CvMat* sampleIdx = NULL;
    CvMat* trimmedIdx;
    //float* idxdata = NULL;
    //float* tempweights = NULL;
    //int    idxcount = 0;
    CvUserdata userdata;

    int i = 0;
    int j = 0;
    int idx;
    int numsamples;
    int numtrimmed;
   
    CvCARTHaarClassifier* classifier;
    CvSeq* seq = NULL;
    CvMemStorage* storage = NULL;
    CvMat* weakTrainVals;
    float alpha;
    float sumalpha;
    int num_splits;

#ifdef CV_VERBOSE
    printf( "+----+----+-+---------+---------+---------+---------+\n" );
    printf( "|  N |%%SMP|F|  ST.THR |    HR     FA   | EXP. ERR|\n" );
    printf( "+----+----+-+---------+---------+---------+---------+\n" );
#endif
   
    n = haarFeatures->count;//这是haar特征的数目,对于32*32的子窗口,特征数目为26万多
    m = data->sum.rows;
    numsamples = (sampleIdx) ? MAX( sampleIdx->rows, sampleIdx->cols ) : m;

    userdata = cvUserdata( data, haarFeatures );

    stumpTrainParams.type = ( boosttype == CV_DABCLASS )
        ? CV_CLASSIFICATION_CLASS : CV_REGRESSION;
    stumpTrainParams.error = ( boosttype == CV_LBCLASS || boosttype == CV_GABCLASS )
        ? CV_SQUARE : stumperror;
    stumpTrainParams.portion = CV_STUMP_TRAIN_PORTION;
    stumpTrainParams.getTrainData = icvGetTrainingDataCallback;
    stumpTrainParams.numcomp = n;
    stumpTrainParams.userdata = &userdata;
    stumpTrainParams.sortedIdx = data->idxcache;//这是对构建cart的每个节点的stump一级决策树参数的设置

    trainParams.count = numsplits;
    trainParams.stumpTrainParams = (CvClassifierTrainParams*) &stumpTrainParams;
    trainParams.stumpConstructor = cvCreateMTStumpClassifier;
    trainParams.splitIdx = icvSplitIndicesCallback;
    trainParams.userdata = &userdata;//这是对cart弱分类器参数的设置

    eval = cvMat( 1, m, CV_32FC1, cvAlloc( sizeof( float ) * m ) );
   
    storage = cvCreateMemStorage();
    seq = cvCreateSeq( 0, sizeof( *seq ), sizeof( classifier ), storage );

    weakTrainVals = cvCreateMat( 1, m, CV_32FC1 );
    trainer = cvBoostStartTraining( &data->cls, weakTrainVals, &data->weights,
                                    sampleIdx, boosttype );//这是用data->cls来计算weakTrainVals。其中weakTrainVals=2*cls-1,cls属于{0,1},则weakTrainVals属于{-1,1}
    num_splits = 0;
    sumalpha = 0.0F;
    do
      

#ifdef CV_VERBOSE
        int v_wt = 0;
        int v_flipped = 0;
#endif

        trimmedIdx = cvTrimWeights( &data->weights, sampleIdx, weightfraction );//剔除小权值,由weightfraction来控制。
        numtrimmed = (trimmedIdx) ? MAX( trimmedIdx->rows, trimmedIdx->cols ) : m;

#ifdef CV_VERBOSE
        v_wt = 100 * numtrimmed / numsamples;
        v_flipped = 0;

#endif

        cart = (CvCARTClassifier*) cvCreateCARTClassifier( data->valcache,
                        flags,
                        weakTrainVals, 0, 0, 0, trimmedIdx,
                        &(data->weights),
                        (CvClassifierTrainParams*) &trainParams );//开始构建cart树弱分类器

        classifier = (CvCARTHaarClassifier*) icvCreateCARTHaarClassifier( numsplits );
        icvInitCARTHaarClassifier( classifier, cart, haarFeatures );

        num_splits += classifier->count;

        cart->release( (CvClassifier**) &cart );
       
        if( symmetric && (seq->total % 2) )
        {
            float normfactor = 0.0F;
            CvStumpClassifier* stump;
           
           
            for( i = 0; i < classifier->count; i++ )
            {
                if( classifier->feature[i].desc[0] == 'h' )
                {
                    for( j = 0; j < CV_HAAR_FEATURE_MAX &&
                                    classifier->feature[i].rect[j].weight != 0.0F; j++ )
                    {
                        classifier->feature[i].rect[j].r.x = data->winsize.width -
                            classifier->feature[i].rect[j].r.x -
                            classifier->feature[i].rect[j].r.width;               
                    }
                }
                else
                {
                    int tmp = 0;

                   
                   
                    for( j = 0; j < CV_HAAR_FEATURE_MAX &&
                                    classifier->feature[i].rect[j].weight != 0.0F; j++ )
                    {
                        classifier->feature[i].rect[j].r.x = data->winsize.width -
                            classifier->feature[i].rect[j].r.x;
                        CV_SWAP( classifier->feature[i].rect[j].r.width,
                                 classifier->feature[i].rect[j].r.height, tmp );
                    }
                }
            }
            icvConvertToFastHaarFeature( classifier->feature,
                                         classifier->fastfeature,
                                         classifier->count, data->winsize.width + 1 );

            stumpTrainParams.getTrainData = NULL;
            stumpTrainParams.numcomp = 1;
            stumpTrainParams.userdata = NULL;
            stumpTrainParams.sortedIdx = NULL;

            for( i = 0; i < classifier->count; i++ )
            {
                for( j = 0; j < numtrimmed; j++ )
                {
                    idx = icvGetIdxAt( trimmedIdx, j );

                    eval.data.fl[idx] = cvEvalFastHaarFeature( &classifier->fastfeature[i],
                        (sum_type*) (data->sum.data.ptr + idx * data->sum.step),
                        (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step) );
                    normfactor = data->normfactor.data.fl[idx];
                    eval.data.fl[idx] = ( normfactor == 0.0F )
                        ? 0.0F : (eval.data.fl[idx] / normfactor);
                }

                stump = (CvStumpClassifier*) trainParams.stumpConstructor( &eval,
                    CV_COL_SAMPLE,
                    weakTrainVals, 0, 0, 0, trimmedIdx,
                    &(data->weights),
                    trainParams.stumpTrainParams );
           
                classifier->threshold[i] = stump->threshold;
                if( classifier->left[i] <= 0 )
                {
                    classifier->val[-classifier->left[i]] = stump->left;
                }
                if( classifier->right[i] <= 0 )
                {
                    classifier->val[-classifier->right[i]] = stump->right;
                }

                stump->release( (CvClassifier**) &stump );       
               
            }

            stumpTrainParams.getTrainData = icvGetTrainingDataCallback;
            stumpTrainParams.numcomp = n;
            stumpTrainParams.userdata = &userdata;
            stumpTrainParams.sortedIdx = data->idxcache;

#ifdef CV_VERBOSE
            v_flipped = 1;
#endif

        }
        if( trimmedIdx != sampleIdx )
        {
            cvReleaseMat( &trimmedIdx );
            trimmedIdx = NULL;
        }
       
        for( i = 0; i < numsamples; i++ )
        {
            idx = icvGetIdxAt( sampleIdx, i );

            eval.data.fl[idx] = classifier->eval_r( (CvIntHaarClassifier*) classifier,
                (sum_type*) (data->sum.data.ptr + idx * data->sum.step),
                (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step),
                data->normfactor.data.fl[idx] );
        }

        alpha = cvBoostNextWeakClassifier( &eval, &data->cls, weakTrainVals,
                                           &data->weights, trainer );
        sumalpha += alpha;
       
        for( i = 0; i <= classifier->count; i++ )
        {
            if( boosttype == CV_RABCLASS )
            {
                classifier->val[i] = cvLogRatio( classifier->val[i] );
            }
            classifier->val[i] *= alpha;
        }

        cvSeqPush( seq, (void*) &classifier );

        numpos = 0;
        for( i = 0; i < numsamples; i++ )
        {
            idx = icvGetIdxAt( sampleIdx, i );

            if( data->cls.data.fl[idx] == 1.0F )
            {
                eval.data.fl[numpos] = 0.0F;
                for( j = 0; j < seq->total; j++ )
                {
                    classifier = *((CvCARTHaarClassifier**) cvGetSeqElem( seq, j ));
                    eval.data.fl[numpos] += classifier->eval_r(
                        (CvIntHaarClassifier*) classifier,
                        (sum_type*) (data->sum.data.ptr + idx * data->sum.step),
                        (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step),
                        data->normfactor.data.fl[idx] );
                }
               
                numpos++;
            }
        }
        icvSort_32f( eval.data.fl, numpos, 0 );
        threshold = eval.data.fl[(int) ((1.0F - minhitrate) * numpos)];

        numneg = 0;
        numfalse = 0;
        for( i = 0; i < numsamples; i++ )
        {
            idx = icvGetIdxAt( sampleIdx, i );

            if( data->cls.data.fl[idx] == 0.0F )
            {
                numneg++;
                sum_stage = 0.0F;
                for( j = 0; j < seq->total; j++ )
                {
                   classifier = *((CvCARTHaarClassifier**) cvGetSeqElem( seq, j ));
                   sum_stage += classifier->eval_r( (CvIntHaarClassifier*) classifier,
                        (sum_type*) (data->sum.data.ptr + idx * data->sum.step),
                        (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step),
                        data->normfactor.data.fl[idx] );
                }
               
                if( sum_stage >= (threshold - CV_THRESHOLD_EPS) )
                {
                    numfalse++;
                }
            }
        }
        falsealarm = ((float) numfalse) / ((float) numneg);

#ifdef CV_VERBOSE
        {
            float v_hitrate    = 0.0F;
            float v_falsealarm = 0.0F;
           
            float v_experr = 0.0F;

            for( i = 0; i < numsamples; i++ )
            {
                idx = icvGetIdxAt( sampleIdx, i );

                sum_stage = 0.0F;
                for( j = 0; j < seq->total; j++ )
                {
                    classifier = *((CvCARTHaarClassifier**) cvGetSeqElem( seq, j ));
                    sum_stage += classifier->eval_r( (CvIntHaarClassifier*) classifier,
                        (sum_type*) (data->sum.data.ptr + idx * data->sum.step),
                        (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step),
                        data->normfactor.data.fl[idx] );
                }
               
                if( sum_stage >= (threshold - CV_THRESHOLD_EPS) )
                {
                    if( data->cls.data.fl[idx] == 1.0F )
                    {
                        v_hitrate += 1.0F;
                    }
                    else
                    {
                        v_falsealarm += 1.0F;
                    }
                }
                if( ( sum_stage >= 0.0F ) != (data->cls.data.fl[idx] == 1.0F) )
                {
                    v_experr += 1.0F;
                }
            }
            v_experr /= numsamples;
            printf( "|M|=%%|%c|�|�|�|�|\n",
                seq->total, v_wt, ( (v_flipped) ? '+' : '-' ),
                threshold, v_hitrate / numpos, v_falsealarm / numneg,
                v_experr );
            printf( "+----+----+-+---------+---------+---------+---------+\n" );
            fflush( stdout );
        }
#endif
       
    } while( falsealarm > maxfalsealarm && (!maxsplits || (num_splits < maxsplits) ) );
    cvBoostEndTraining( &trainer );

    if( falsealarm > maxfalsealarm )
    {
        stage = NULL;
    }
    else
    {
        stage = (CvStageHaarClassifier*) icvCreateStageHaarClassifier( seq->total,
                                                                       threshold );
        cvCvtSeqToArray( seq, (CvArr*) stage->classifier );
    }
   
   
    cvReleaseMemStorage( &storage );
    cvReleaseMat( &weakTrainVals );
    cvFree( &(eval.data.ptr) );
   
    return (CvIntHaarClassifier*) stage;
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值