实现了KD平衡树的程序,由于MATLAB实现需要用到引用类型或者采用循环实现(见MATLAB的KDTreeSearcher.m),因此采用C#来实现
using System;
using System.Collections.Generic;
using System.Linq;
namespace KNNSearch
{
///
/// Description of KNN.
///
public class KNN
{
///
/// 叶子节点点的个数
///
private int leafnum = 30;
///
/// 待分类数据
///
private List rawData;
///
/// 生成原始数据
///
private void GeneralRawData()
{
if (rawData == null)
{
rawData = new List();
}
else
{
rawData.Clear();
}
Random r = new Random();
for (int i = 0; i < 500; i++)
{
rawData.Add(new Point() { X = r.NextDouble(), Y = r.NextDouble(), Z = r.NextDouble() });
}
}
///
/// 创建KD树
///
///
///
private Node CreateKDTree(List data)
{
// 创建根节点
Node root = new Node();
// 添加当前节点数据
root.nodeData = data;
// 如果节点的数据数量小于叶子节点的数量限制,则当前节点为叶子节点
if (data.Count <= leafnum)
{
root.leftNode = null;
root.rightNode = null;
root.point = double.NaN;
root.splitaxis = -1;
return root;
}
// 找到分割轴
int splitAxis = GetSplitAxis(data);
// 分割数据
Tuple, List> dataSplit = GetSplitNum(data, splitAxis);
root.splitaxis = splitAxis;
root.point = dataSplit.Item1;
root.leftNode = CreateKDTree(dataSplit.Item2);
root.rightNode = CreateKDTree(dataSplit.Item3);
return root;
}
private Tuple, List> GetSplitNum(List data, int splitAxis)
{
// 对数据按照第splitAxis排序
var data0 = splitAxis == 0 ? (data.OrderBy(x => x.X)).ToList() :
(splitAxis == 1 ? (data.OrderBy(x => x.Y)).ToList() :
(data.OrderBy(x => x.Z)).ToList());
int half = data0.Count / 2;
List leftdata = new List();
List rightdata = new List();
for (int i = 0; i < data0.Count; i++)
{
if (i <= half)
{
leftdata.Add(data0[i]);
}
else
{
rightdata.Add(data0[i]);
}
}
double splitnum = splitAxis == 0 ? data[half].X + data[half + 1].X :
(splitAxis == 1 ? data[half].Y + data[half + 1].Y :
data[half].Z + data[half + 1].Z);
return new Tuple, List>(splitnum / 2, leftdata, rightdata);
}
///
/// 获取分割轴编号
///
///
///
private int GetSplitAxis(List data)
{
// 设定数据范围最大的轴作为分割轴(也有其他的方式,如方差,或者轮流的方式)
var xData = data.Select(item => item.X);
var yData = data.Select(item => item.Y);
var zData = data.Select(item => item.Z);
List ranges = new List();
ranges.Add(xData.Max() - xData.Min());
ranges.Add(yData.Max() - yData.Min());
ranges.Add(zData.Max() - zData.Min());
var sorted = ranges.Select((x, i) => new KeyValuePair(x, i)).OrderByDescending(x => x.Key).ToList();
return sorted.Select(x => x.Value).ToList()[0]; ;
}
public KNN()
{
GeneralRawData();
Node node = CreateKDTree(rawData);
}
}
///
/// Description of Node.
///
public class Node
{
///
/// 切分的阈值点
///
public double point;
///
/// 左节点
///
public Node leftNode;
///
/// 右节点
///
public Node rightNode;
///
/// 节点包含的数据
///
public List nodeData;
///
/// 分割轴
///
public int splitaxis;
public Node()
{
}
}
public class Point
{
public double X;
public double Y;
public double Z;
}
}