直接上代码:
// 基于RANSAC算法的直线拟合
// pstData: 指向存储数据的指针
// dataCnt: 数据点个数
// lineParameterK: 直线的斜率
// lineParameterB: 直线的截距
// minCnt: 模型(直线)参数估计所需的数据点的个数
// maxIterCnt: 最大迭代次数
// maxErrorThreshold: 最大误差阈值
// consensusCntThreshold: 模型一致性判断准则
// modelMeanError: 模型误差
// 返回值: 返回0表示获取最优模型, 否则表示未获取最优模型
//if(!ransacLiner(dataPoints, totalCnt, 2, 50, 35, 0.1, A, B, C, meanError))
int ransacLiner(st_Point* pstData, int dataCnt, int minCnt, double maxIterCnt, int consensusCntThreshold,
double maxErrorThreshold, double& A, double& B, double& C, double& modelMeanError,
set<unsigned int> &consensusIndexs)
{
default_random_engine rng;
uniform_int_distribution<unsigned> uniform(0, dataCnt - 1);
rng.seed(8); // 固定随机数种子
set<unsigned int> selectIndexs; // 选择的点的索引
vector<st_Point> selectPoints; // 选择的点
double temp_A = 0;
double temp_B = 0;
double temp_C = 0;
modelMeanError = 0;
int isNonFind = 1;
unsigned int bestConsensusCnt = 0; // 满足一致性估计的点的个数
int iter = 0;
while(iter < maxIterCnt)
{
selectIndexs.clear();
selectPoints.clear();
// Step1: 随机选择minCnt个点
while(1)
{
unsigned int index = uniform(rng);
selectIndexs.insert(index);
if(selectIndexs.size() == minCnt)
{
cout << "selectIndexs.size() == minCnt: " << endl;
break;
}
}
// Step2: 进行模型参数估计 (y2 - y1)*x - (x2 - x1)*y + (y2 - y1)x2 - (x2 - x1)y2= 0
set<unsigned int>::iterator selectIter = selectIndexs.begin();
while(selectIter != selectIndexs.end())
{
unsigned int index = *selectIter;
selectPoints.push_back(pstData[index]);
selectIter++;
}
double deltaY = (selectPoints[1]).y - (selectPoints[0]).y;
double deltaX = (selectPoints[1]).x - (selectPoints[0]).x;
temp_A = deltaY;
temp_B = -deltaX;
temp_C = -deltaY * (selectPoints[1]).x + deltaX * (selectPoints[1]).y;
// Step3: 进行模型评估: 点到直线的距离
int dataIter = 0;
double meanError = 0;
set<unsigned int> tmpConsensusIndexs;
while(dataIter < dataCnt)
{
double distance =
(temp_A * pstData[dataIter].x + temp_B * pstData[dataIter].y + temp_C) / sqrt(temp_A*temp_A + temp_B*temp_B);
distance = distance > 0 ? distance : -distance;
if(distance <= maxErrorThreshold)
{
tmpConsensusIndexs.insert(dataIter);
}
meanError += distance;
dataIter++;
}
// Step4: 判断一致性: 满足一致性集合的最小元素个数条件 + 至少比上一次的好
if(tmpConsensusIndexs.size() >= consensusCntThreshold && tmpConsensusIndexs.size() >= bestConsensusCnt )
{
A = temp_A;
B = temp_B;
C = temp_C;
bestConsensusCnt = consensusIndexs.size(); // 更新一致性索引集合元素个数
modelMeanError = meanError / dataCnt;
consensusIndexs.clear();
consensusIndexs = tmpConsensusIndexs; // 更新一致性索引集合
isNonFind = 0;
cout << "bestConsensusCnt: " << bestConsensusCnt << endl;
cout << "tmpConsensusIndexs: " << tmpConsensusIndexs.size() << endl;
}
iter++;
}
return isNonFind;
}
参考链接:https://blog.youkuaiyun.com/hit1524468/article/details/80375495
参考链接代码有些小问题,本文中与参考链接代码稍有不同,ABC需要在函数外边定义,并将函数内ABC修改为temp_A、temp_B、temp_C,然后找到最优直线,将最优直线的参数temp_A、temp_B、temp_更新到ABC,以此来保存最优直线的参数。