训练数据data.txt中有112行数据,其中前12行是已有数据,后面100行是在完成算法后随机生成的。
这里KNN算法的实现是先求出所有的距离,下一篇则使用维护一个k长度队列的方式。
KNN类的声明
#include <iostream>
#include <fstream>
#include <map>
#include <vector>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#define _maxRow 112
#define _maxCol 2
using namespace std;
typedef pair<int,double> PAIR;
class KNN
{
public:
KNN();
bool printData(); //测试用函数
void getAllDistance(); //获取所有点到测试数据的距离
void getMaxFreLabel(); //获取范围内训练数据中出现频率最高的标签
void show(); //打印范围内所有训练数据的信息
private:
int k;
double data[_maxRow][_maxCol]; //用于存储训练数据的坐标
char labels[_maxRow]; //用于存储训练数据的标签
double testData[_maxCol]; //用于存储测试数据坐标
map<int,double> index_dis; //用于存储每个训练数据的距离
map<char,int> label_times; //存储每个标签的出现频率
bool readData(); //读取data.txt文档中的数据
double getDistance(double* d1,double* d2); //获取测试数据和所有训练数据间的距离
struct CmpByValue //提供pair<int,double>类的比较方法
{
bool operator() (const PAIR& lhs,const PAIR& rhs)
{
return lhs.second < rhs.second;
}
};
};
构造函数
KNN::KNN()
{
if(!readData()) //读取训练数据
exit(1);
cout<<"输入测试数据:"<<endl;
for(int i=0; i<_maxCol; i++)
cin>>testData[i];
cout<<endl;
cout<<"请输入k:";
cin>>k;
}
读入数据
bool KNN::readData()
{
ifstream fin;
fin.open("data.txt");
//打开文件
if(!fin)
{
cout<<"找不到文件!"<<endl;
return false;
}
for(int i=0; i<_maxRow; i++)
{
for(int j=0; j<_maxCol; j++)
{
fin>>data[i][j];
}
fin>>labels[i];
}
//读入数据
fin.close();
return true;
//关闭文件
}
获取测试数据和训练数据的距离
/*获取两点间距*/
double KNN::getDistance(double* d1,double* d2)
{
double sum=0;
for(int i=0; i<_maxCol; i++)
{
sum+=pow((d1[i]-d2[i]),2);
}
return sqrt(sum);
}
/*获取测试数据和所有训练数据的距离*/
void KNN::getAllDistance()
{
double distance;
for(int i=0; i<_maxRow; i++)
{
distance=getDistance(data[i],testData);
index_dis[i]=distance; //将距离存入对应的map
}
}
核心部分
这里是直接将所有距离进行排序,并选出距离最小的k个数据。
这里也可以维护大小为k的队列,先取出前k个数据存入队列,然后遍历距离L,若L>=Lmax则继续遍历,若L<Lmax,则用L替换Lmax,并对队列进行排序。
/*获取k个距离最近的训练数据中出现频率最高的标签*/
void KNN::getMaxFreLabel()
{
vector<pair<int,double> > vec_index_dis(index_dis.begin(),index_dis.end());//可尝试改为PAIR
sort(vec_index_dis.begin(),vec_index_dis.end(),CmpByValue()); //对所有距离排序
/*可尝试通过维护一个大小为k的队列,实现同样的功能*/
//cout<<"距离最小的"<<k<<"个数据为:";
for(int i=0;i<k;i++)
{
label_times[labels[vec_index_dis[i].first]]++;
/* vec_index_dis[i].first为队列中第i个数据的first,即对应训练数据的序号
labels[]为序号对应的标签
label_times[]++将标签对应的出现频率自增*/
}
map<char,int>::iterator itr=label_times.begin();
int maxFreq=0;
char testLabel;
while(itr!=label_times.end())
{
if(itr->second>maxFreq)
{
maxFreq=itr->second;
testLabel=itr->first;
}
itr++;
}
/*读取频率最高的标签并打印*/
cout<<"数据属于标签:"<<testLabel<<endl;
}
主函数
int main()
{
KNN knn;
knn.getAllDistance();
knn.getMaxFreLabel();
return 0;
}
参考资料:https://blog.youkuaiyun.com/lavorange/article/details/16924705