KNN算法的C#代码。上一篇博客中
C#创建KD树的程序,其算法模仿的是MATLAB中KDTree的程序思路。
这次按照李航老师《统计学习方法》中的思路写一个C#程序,其中创建KD树时分割维度的选取并不是轮询,而是选择数据取值范围最大的那个维度。
using System;
using System.Collections.Generic;
using System.Linq;
namespace KNNSearch
{
    /// <summary>
    /// Demonstrates building a KD tree over random 3-D points and finding the
    /// nearest neighbour of a query point, following the approach in
    /// Li Hang, "Statistical Learning Methods" (KD-tree construction and search).
    /// </summary>
    /// <remarks>
    /// Unlike a round-robin KD tree, each node splits on the axis whose values
    /// span the largest range within that node's data.
    /// The constructor runs the whole demo and prints the KD-tree result next to
    /// a brute-force result so the two can be compared.
    /// </remarks>
    public class Knn
    {
        /// <summary>
        /// Coordinate accessors indexed by axis number: 0 = X, 1 = Y, 2 = Z.
        /// Built once (the original rebuilt a dictionary on every access).
        /// </summary>
        private static readonly Func<Point, double>[] Axes =
        {
            p => p.X,
            p => p.Y,
            p => p.Z,
        };

        /// <summary>
        /// Maximum number of points a node may hold and still be a leaf.
        /// </summary>
        private readonly int leafnum = 1;

        /// <summary>
        /// Candidate node names "A".."Z", exposed through <see cref="NodeNames"/>.
        /// Tree nodes are currently all named "AA" by <see cref="CreateKdTree"/>.
        /// </summary>
        private List<string> _nodeNames = new List<string>
        {
            "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
            "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z",
        };

        /// <summary>
        /// Gets or sets the list of candidate node names.
        /// </summary>
        public List<string> NodeNames
        {
            get => _nodeNames;
            set => _nodeNames = value;
        }

        /// <summary>
        /// Runs the demo: builds a KD tree over 180 seeded random points, searches it
        /// for the point nearest to (0.5, 0.5, 0.5), and prints that distance next to
        /// the distance found by scanning every point.
        /// </summary>
        public Knn()
        {
            List<Point> rawData = GeneralRawData(180);
            Node root = CreateKdTree(rawData);
            var target = new Point { X = 0.5, Y = 0.5, Z = 0.5 };

            // Nearest distance found through the KD-tree search.
            Point nearest = KdTreeFindNearest(root, target);
            double nearestDistFromKnn = Distance(nearest, target);
            Console.WriteLine("通过KNN搜索计算得到的最短距离为{0:F3}", nearestDistFromKnn);

            // Reference value from a brute-force scan over all points.
            double nearestDistFromLoop = NearestDist(rawData, target);
            Console.WriteLine("通过KNN遍历计算得到的最短距离为{0:F3}", nearestDistFromLoop);
        }

        /// <summary>
        /// Generates <paramref name="num"/> points whose coordinates come from
        /// <see cref="Random.NextDouble"/>. The fixed seed (1) makes every run identical.
        /// </summary>
        /// <param name="num">Number of points; point i receives ID i. Non-positive yields an empty list.</param>
        /// <returns>The generated points.</returns>
        private List<Point> GeneralRawData(int num)
        {
            var rawData = new List<Point>(Math.Max(num, 0));
            var r = new Random(1);
            for (var i = 0; i < num; i++)
            {
                rawData.Add(new Point { X = r.NextDouble(), Y = r.NextDouble(), Z = r.NextDouble(), ID = i });
            }

            return rawData;
        }

        /// <summary>
        /// Recursively builds a KD tree over <paramref name="data"/>.
        /// </summary>
        /// <remarks>
        /// A node with at most <see cref="leafnum"/> points becomes a leaf
        /// (<see cref="Node.Splitaxis"/> = -1, <see cref="Node.Point"/> = its first point).
        /// Otherwise the node splits on <see cref="GetSplitAxis"/>'s axis at the median point.
        /// Points sorted before the median go left and points after it go right.
        /// Every node keeps its whole subset in <see cref="Node.NodeData"/>.
        /// </remarks>
        /// <param name="data">Points belonging to this subtree.</param>
        /// <returns>The subtree root, or null when <paramref name="data"/> is empty.</returns>
        private Node CreateKdTree(List<Point> data)
        {
            if (data.Count == 0)
            {
                return null;
            }

            var node = new Node { NodeData = data, Name = "AA" };

            if (data.Count <= leafnum)
            {
                node.Point = data[0];
                node.Splitaxis = -1;
                return node;
            }

            int splitAxis = GetSplitAxis(data);
            Tuple<Point, List<Point>, List<Point>> split = GetSplitNum(data, splitAxis);
            node.Splitaxis = splitAxis;
            node.Point = split.Item1;
            node.LeftNode = CreateKdTree(split.Item2);
            node.RightNode = CreateKdTree(split.Item3);
            return node;
        }

        /// <summary>
        /// Sorts <paramref name="data"/> on <paramref name="splitAxis"/> (stable sort) and splits it at index Count / 2.
        /// </summary>
        /// <param name="data">Non-empty list of points to split.</param>
        /// <param name="splitAxis">Axis index into <see cref="Axes"/>.</param>
        /// <returns>(median point, points before it, points after it) in sorted order.</returns>
        private Tuple<Point, List<Point>, List<Point>> GetSplitNum(List<Point> data, int splitAxis)
        {
            List<Point> sorted = data.OrderBy(Axes[splitAxis]).ToList();
            int half = sorted.Count / 2;
            List<Point> left = sorted.GetRange(0, half);
            List<Point> right = sorted.GetRange(half + 1, sorted.Count - half - 1);
            return Tuple.Create(sorted[half], left, right);
        }

        /// <summary>
        /// Picks the axis whose values span the largest range (max - min) in <paramref name="data"/>.
        /// Variance or round-robin are alternative strategies.
        /// </summary>
        /// <param name="data">Non-empty list of points.</param>
        /// <returns>The axis index; on ties the lowest index wins.</returns>
        private int GetSplitAxis(List<Point> data)
        {
            int bestAxis = 0;
            double bestRange = double.NegativeInfinity;
            for (int axis = 0; axis < Axes.Length; axis++)
            {
                Func<Point, double> coordinate = Axes[axis];
                double range = data.Max(coordinate) - data.Min(coordinate);
                if (range > bestRange)
                {
                    bestRange = range;
                    bestAxis = axis;
                }
            }

            return bestAxis;
        }

        /// <summary>
        /// Finds the point in <paramref name="tree"/> with the smallest Euclidean
        /// distance to <paramref name="target"/>.
        /// </summary>
        /// <param name="tree">Root returned by <see cref="CreateKdTree"/>; may be null.</param>
        /// <param name="target">Query point.</param>
        /// <returns>The nearest point, or null when the tree is empty.</returns>
        private Point KdTreeFindNearest(Node tree, Point target)
        {
            Point best = null;
            double bestDist = double.PositiveInfinity;
            SearchSubtree(tree, target, ref best, ref bestDist);
            return best;
        }

        /// <summary>
        /// Recursive search step (Li Hang, Algorithm 3.3).
        /// First descend into the child on the target's side of the split.
        /// Then test the node's own point.
        /// Finally visit the other child only when the splitting plane is closer
        /// than the best distance so far, because only then can the ball around
        /// the target reach points on the far side.
        /// </summary>
        private static void SearchSubtree(Node node, Point target, ref Point best, ref double bestDist)
        {
            if (node == null)
            {
                return;
            }

            if (node.Splitaxis < 0)
            {
                // Leaf: check every point it stores (a leaf may hold up to leafnum points).
                foreach (Point candidate in node.NodeData)
                {
                    Consider(candidate, target, ref best, ref bestDist);
                }

                return;
            }

            Func<Point, double> coordinate = Axes[node.Splitaxis];
            double offset = coordinate(target) - coordinate(node.Point);

            // Left-subtree points are <= the split value and right-subtree points are >=,
            // so which side is searched first only affects speed; the plane test below
            // guarantees correctness.
            bool nearIsLeft = offset <= 0;
            SearchSubtree(nearIsLeft ? node.LeftNode : node.RightNode, target, ref best, ref bestDist);

            Consider(node.Point, target, ref best, ref bestDist);

            // Every far-side point is at least |offset| away from the target.
            if (Math.Abs(offset) < bestDist)
            {
                SearchSubtree(nearIsLeft ? node.RightNode : node.LeftNode, target, ref best, ref bestDist);
            }
        }

        /// <summary>
        /// Replaces the current best with <paramref name="candidate"/> if it is strictly closer to <paramref name="target"/>.
        /// </summary>
        private static void Consider(Point candidate, Point target, ref Point best, ref double bestDist)
        {
            double d = Distance(candidate, target);
            if (d < bestDist)
            {
                bestDist = d;
                best = candidate;
            }
        }

        /// <summary>
        /// Euclidean distance between two points.
        /// </summary>
        private static double Distance(Point a, Point b)
        {
            double dx = a.X - b.X;
            double dy = a.Y - b.Y;
            double dz = a.Z - b.Z;
            return Math.Sqrt(dx * dx + dy * dy + dz * dz);
        }

        /// <summary>
        /// Brute-force minimum Euclidean distance from <paramref name="target"/> to any point in <paramref name="nodeData"/>.
        /// </summary>
        /// <param name="nodeData">Points to scan.</param>
        /// <param name="target">Query point.</param>
        /// <returns>The smallest distance.</returns>
        /// <exception cref="InvalidOperationException"><paramref name="nodeData"/> is empty.</exception>
        private double NearestDist(List<Point> nodeData, Point target)
        {
            return nodeData.Min(item => Distance(item, target));
        }

        /// <summary>
        /// Debug helper: prints a separator followed by every point in <paramref name="data"/>.
        /// </summary>
        private void PrintListData(List<Point> data)
        {
            Console.WriteLine("****************");
            foreach (Point point in data)
            {
                Console.WriteLine(point);
            }
        }
    }

    /// <summary>
    /// A node of the KD tree built by <see cref="Knn"/>.
    /// </summary>
    public class Node
    {
        /// <summary>
        /// Node name (the tree builder currently assigns "AA" to every node).
        /// </summary>
        public string Name;

        /// <summary>
        /// The median point whose coordinate on <see cref="Splitaxis"/> is the split threshold;
        /// for a leaf, its first point.
        /// </summary>
        public Point Point;

        /// <summary>
        /// Subtree of points sorted before the split point on <see cref="Splitaxis"/>.
        /// </summary>
        public Node LeftNode;

        /// <summary>
        /// Subtree of points sorted after the split point on <see cref="Splitaxis"/>.
        /// </summary>
        public Node RightNode;

        /// <summary>
        /// All points contained in this node's subtree (including <see cref="Point"/>).
        /// </summary>
        public List<Point> NodeData;

        /// <summary>
        /// Split axis: 0 = X, 1 = Y, 2 = Z; -1 marks a leaf.
        /// </summary>
        public int Splitaxis;
    }

    /// <summary>
    /// A point in 3-D space.
    /// </summary>
    public class Point
    {
        /// <summary>X coordinate.</summary>
        public double X;

        /// <summary>Y coordinate.</summary>
        public double Y;

        /// <summary>Z coordinate.</summary>
        public double Z;

        /// <summary>Identifier used for debugging output.</summary>
        public int ID;

        /// <summary>
        /// Formats the point as "(X,Y,Z,ID)". Invariant culture is used so that a
        /// culture's decimal comma can never be confused with the comma delimiters.
        /// </summary>
        public override string ToString()
        {
            return FormattableString.Invariant($"({X},{Y},{Z},{ID})");
        }
    }
}