The problem is as shown in the figure: for each of the n points, find the distance to its k-th nearest point.
m = 4, i.e. every point is a point in 4-dimensional space.
n is not fixed; think tens of thousands, hundreds of thousands, or even more than ten million points.
The computation is done with Spark. Resource configuration: executor-cores: 6, executor-memory: 10G.
Solution 1:
First put the point matrix into a DataFrame (each Row is [uuid, double1, double2, double3, double4]).
Join the DataFrame with itself, then map each joined row to the distance between the two points, returning a JavaPairRDD whose key is the point's uuid and whose value is the Euclidean distance.
Then groupByKey, which gives a JavaPairRDD whose key is the point's uuid and whose value is an Iterable holding that point's distances to all other points.
Finally map again: sort each point's distance Iterable and take the distance to its k-th nearest point.
The code is as follows:
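What follows is a minimal sketch of this approach (Spark 1.x Java API; the DataFrame points, the column names d1..d4 / e1..e4, and the constant k are illustrative assumptions, not the original listing):
final int k = 8;                       //illustrative: which nearest neighbor we want
DataFrame left  = points.toDF("uuid", "d1", "d2", "d3", "d4");
DataFrame right = points.toDF("uuid2", "e1", "e2", "e3", "e4");
//self-join: pair every point with every other point (n^2 rows)
DataFrame pairs = left.join(right, left.col("uuid").notEqual(right.col("uuid2")));
//Euclidean distance per pair, keyed by the left point's uuid
JavaPairRDD<String, Double> dists = pairs.toJavaRDD().mapToPair(
        new PairFunction<Row, String, Double>() {
            public Tuple2<String, Double> call(Row r) {
                double sum = 0;
                for (int i = 0; i < 4; i++) {
                    double diff = r.getDouble(1 + i) - r.getDouble(6 + i);
                    sum += diff * diff;
                }
                return new Tuple2<String, Double>(r.getString(0), Math.sqrt(sum));
            }
        });
//group all distances per point, sort them, and take the k-th smallest
JavaPairRDD<String, Double> kthDist = dists.groupByKey().mapValues(
        new Function<Iterable<Double>, Double>() {
            public Double call(Iterable<Double> ds) {
                List<Double> sorted = new ArrayList<Double>();
                for (Double d : ds) {
                    sorted.add(d);
                }
                Collections.sort(sorted);
                return sorted.get(k - 1);   //distance to the k-th nearest point
            }
        });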
Performance: with n = 10000 it takes roughly 3 minutes.
Bottleneck analysis:
1. The computational complexity is O(n^2); with n = 10000 that is already on the order of 10^8 pairs.
2. The uuid is a string, which increases the amount of data moved during the distributed shuffle.
3. The Euclidean distance requires squaring and a square root; the floating-point math is relatively expensive.
Solution 2 (sketched below):
1. Replace the Euclidean distance with the Manhattan distance.
2. Generate the uuid with zipWithUniqueId, so it is a long.
3. Do not use a join; generate the pairs with the RDD cartesian method.
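A minimal sketch of these three changes (Spark 1.x Java API; the input JavaRDD<double[]> named points and the downstream steps are illustrative assumptions, not the original listing):
//assume a JavaRDD<double[]> points of 4-dimensional points
JavaPairRDD<double[], Long> withId = points.zipWithUniqueId();   //long ids instead of string uuids
JavaPairRDD<Long, Double> dists = withId.cartesian(withId)
        .filter(new Function<Tuple2<Tuple2<double[], Long>, Tuple2<double[], Long>>, Boolean>() {
            public Boolean call(Tuple2<Tuple2<double[], Long>, Tuple2<double[], Long>> pair) {
                return !pair._1()._2().equals(pair._2()._2());   //drop self-pairs
            }
        })
        .mapToPair(new PairFunction<Tuple2<Tuple2<double[], Long>, Tuple2<double[], Long>>, Long, Double>() {
            public Tuple2<Long, Double> call(Tuple2<Tuple2<double[], Long>, Tuple2<double[], Long>> pair) {
                double dist = 0;
                for (int i = 0; i < 4; i++) {
                    dist += Math.abs(pair._1()._1()[i] - pair._2()._1()[i]);   //Manhattan distance
                }
                return new Tuple2<Long, Double>(pair._1()._2(), dist);
            }
        });
//from here on it is the same as solution 1: groupByKey, sort each point's distances, take the k-th smallest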
Performance: not much different from solution 1.
In summary, the dominant cost is still computing each point's distance to every other point (the join in solution 1 and the cartesian in solution 2 are what take the time).
To solve the problem at its root, the step of computing each point's distance to every other point has to be cut out entirely.
A k-d tree is a binary tree that recursively partitions the points, cycling the splitting dimension as curDimension % DimensionNum.
When searching for a point's k nearest neighbors, we only need a binary descent of the tree while maintaining a bounded priority queue, which skips the compute-every-pairwise-distance step entirely.
The code is as follows:
HPoint.java
public class HPoint implements Serializable {
protected double[] coord;
protected HPoint(int n) {
coord = new double[n];
}
protected HPoint(double[] x) {
coord = new double[x.length];
for (int i = 0; i < x.length; ++i) {
coord[i] = x[i];
}
}
protected Object clone() {
return new HPoint(coord);
}
protected boolean equals(HPoint p) {
for (int i = 0; i < coord.length; ++i) {
if (coord[i] != p.coord[i]) {
return false;
}
}
return true;
}
//Manhattan distance
protected static double manhanttandist(HPoint x,HPoint y){
double dist = 0;
for(int i=0; i < x.coord.length; ++i){
dist += Math.abs(x.coord[i] - y.coord[i]);
}
return dist;
}
//squared Euclidean distance
protected static double sqrdist(HPoint x, HPoint y) {
double dist = 0;
for (int i = 0; i < x.coord.length; ++i) {
double diff = (x.coord[i] - y.coord[i]);
dist += (diff * diff);
}
return dist;
}
//Euclidean distance
protected static double eucdist(HPoint x, HPoint y) {
return Math.sqrt(sqrdist(x, y));
}
public String toString() {
String s = "";
for (int i = 0; i < coord.length; ++i) {
s = s + coord[i] + " ";
}
return s;
}
}
HRect.java
public class HRect implements Serializable {
protected HPoint min;
protected HPoint max;
protected HRect(int ndims) {
min = new HPoint(ndims);
max = new HPoint(ndims);
}
protected HRect(HPoint vmin, HPoint vmax) {
min = (HPoint) vmin.clone();
max = (HPoint) vmax.clone();
}
protected Object clone() {
return new HRect(min, max);
}
//return the point inside this rectangle that is closest to t
protected HPoint closest(HPoint t) {
HPoint p = new HPoint(t.coord.length);
for (int i = 0; i < t.coord.length; ++i) {
if (t.coord[i] <= min.coord[i]) {
p.coord[i] = min.coord[i];
} else if (t.coord[i] >= max.coord[i]) {
p.coord[i] = max.coord[i];
} else {
p.coord[i] = t.coord[i];
}
}
return p;
}
//create an unbounded d-dimensional rectangle
protected static HRect infiniteHRect(int d) {
HPoint vmin = new HPoint(d);
HPoint vmax = new HPoint(d);
for (int i = 0; i < d; ++i) {
vmin.coord[i] = Double.NEGATIVE_INFINITY;
vmax.coord[i] = Double.POSITIVE_INFINITY;
}
return new HRect(vmin, vmax);
}
public String toString() {
return min + "\n" + max + "\n";
}
}
KDNode.java
//k-d tree node
public class KDNode<T> implements Serializable {
protected HPoint k;
protected KDNode left, right;
protected boolean deleted;
T v;
private KDNode(HPoint key, T val) {
k = key;
v = val;
left = null;
right = null;
deleted = false;
}
//insert a node
protected static KDNode ins(HPoint key, Object val, KDNode t, int lev, int K) {
if (t == null) {
t = new KDNode(key, val);
} else if (key.equals(t.k)) {
//the key duplicates this node; if the node was marked as deleted, restore it and update its value
if (t.deleted) {
t.deleted = false;
t.v = val;
}
} else if (key.coord[lev] > t.k.coord[lev]) {
t.right = ins(key, val, t.right, (lev + 1) % K, K);
} else {
t.left = ins(key, val, t.left, (lev + 1) % K, K);
}
return t;
}
//exact-match search
protected static KDNode srch(HPoint key, KDNode t, int K) {
for (int lev = 0; t != null; lev = (lev + 1) % K) {
if (!t.deleted && key.equals(t.k)) {
return t;
} else if (key.coord[lev] > t.k.coord[lev]) {
t = t.right;
} else {
t = t.left;
}
}
return null;
}
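//range search: collect into v every node whose key lies inside the axis-aligned box [lowk, uppk]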
protected static void rsearch(HPoint lowk, HPoint uppk, KDNode t, int lev, int K, Vector<KDNode> v) {
if (t == null) {
return;
}
if (lowk.coord[lev] <= t.k.coord[lev]) {
rsearch(lowk, uppk, t.left, (lev + 1) % K, K, v);
}
int j;
for (j = 0; j < K && lowk.coord[j] <= t.k.coord[j] && uppk.coord[j] >= t.k.coord[j]; j++)
;
if (j == K) {
v.add(t);
}
if (uppk.coord[lev] > t.k.coord[lev]) {
rsearch(lowk, uppk, t.right, (lev + 1) % K, K, v);
}
}
//nearest-neighbor search: descend the tree, pruning any subtree whose bounding rectangle cannot contain a point closer than the current worst candidate in nnl
protected static void nnbr(KDNode kd, HPoint target, HRect hr, double max_dist_sqd, int lev, int K,
NearestNeighborList nnl) {
if (kd == null) {
return;
}
int s = lev % K;
HPoint pivot = kd.k;
double pivot_to_target = HPoint.manhanttandist(pivot, target);
HRect left_hr = hr;
HRect right_hr = (HRect) hr.clone();
left_hr.max.coord[s] = pivot.coord[s];
right_hr.min.coord[s] = pivot.coord[s];
boolean target_in_left = target.coord[s] < pivot.coord[s];
KDNode nearer_kd;
HRect nearer_hr;
KDNode further_kd;
HRect further_hr;
if (target_in_left) {
nearer_kd = kd.left;
nearer_hr = left_hr;
further_kd = kd.right;
further_hr = right_hr;
}
else {
nearer_kd = kd.right;
nearer_hr = right_hr;
further_kd = kd.left;
further_hr = left_hr;
}
nnbr(nearer_kd, target, nearer_hr, max_dist_sqd, lev + 1, K, nnl);
KDNode nearest = (KDNode) nnl.getHighest();
double dist_sqd;
if (!nnl.isCapacityReached()) {
dist_sqd = Double.MAX_VALUE;
} else {
dist_sqd = nnl.getMaxPriority();
}
max_dist_sqd = Math.min(max_dist_sqd, dist_sqd);
HPoint closest = further_hr.closest(target);
if (Double.valueOf(HPoint.manhanttandist(closest, target)).compareTo(max_dist_sqd) < 0) {
if (pivot_to_target < dist_sqd) {
nearest = kd;
dist_sqd = pivot_to_target;
if (!kd.deleted) {
nnl.insert(kd, dist_sqd);
}
if (nnl.isCapacityReached()) {
max_dist_sqd = nnl.getMaxPriority();
} else {
max_dist_sqd = Double.MAX_VALUE;
}
}
nnbr(further_kd, target, further_hr, max_dist_sqd, lev + 1, K, nnl);
KDNode temp_nearest = (KDNode) nnl.getHighest();
double temp_dist_sqd = nnl.getMaxPriority();
if (temp_dist_sqd < dist_sqd) {
nearest = temp_nearest;
dist_sqd = temp_dist_sqd;
}
}
else if (pivot_to_target < max_dist_sqd) {
nearest = kd;
dist_sqd = pivot_to_target;
}
}
private static String pad(int n) {
String s = "";
for (int i = 0; i < n; ++i) {
s += " ";
}
return s;
}
private static void hrcopy(HRect hr_src, HRect hr_dst) {
hpcopy(hr_src.min, hr_dst.min);
hpcopy(hr_src.max, hr_dst.max);
}
private static void hpcopy(HPoint hp_src, HPoint hp_dst) {
for (int i = 0; i < hp_dst.coord.length; ++i) {
hp_dst.coord[i] = hp_src.coord[i];
}
}
protected String toString(int depth) {
String s = k + " " + v + (deleted ? "*" : "");
if (left != null) {
s = s + "\n" + pad(depth) + "L " + left.toString(depth + 1);
}
if (right != null) {
s = s + "\n" + pad(depth) + "R " + right.toString(depth + 1);
}
return s;
}
}
KDTree.java
/**
 * k-d tree
 */
public class KDTree<T> implements java.io.Serializable {
//number of dimensions
private int m_K;
//root node
private KDNode m_root;
//number of nodes in the tree
private int m_count;
//create a k-dimensional k-d tree
public KDTree(int k) {
m_K = k;
m_root = null;
}
//insert a node into the k-d tree
//key is the k-dimensional coordinate
//value is the label attached to the node
public void insert(double[] key, T value) {
if (key.length != m_K) {
throw new RuntimeException("KDTree: wrong key size!");
} else {
m_root = KDNode.ins(new HPoint(key), value, m_root, 0, m_K);
}
m_count++;
}
//search the k-d tree for the node whose key exactly matches
public Object search(double[] key) {
if (key.length != m_K) {
throw new RuntimeException("KDTree: wrong key size!");
}
KDNode kd = KDNode.srch(new HPoint(key), m_root, m_K);
return (kd == null ? null : kd.v);
}
//delete a k-d tree node (lazily, by marking it as deleted)
public void delete(double[] key) {
if (key.length != m_K) {
throw new RuntimeException("KDTree: wrong key size!");
} else {
KDNode t = KDNode.srch(new HPoint(key), m_root, m_K);
if (t == null) {
throw new RuntimeException("KDTree: key missing!");
} else {
t.deleted = true;
}
m_count--;
}
}
//find the nearest node
public T nearest(double[] key) {
List<T> nbrs = nearest(key, 1);
return nbrs.get(0);
}
//find the n nearest nodes
public List<T> nearest(double[] key, int n) {
if (n < 0 || n > m_count) {
throw new IllegalArgumentException("Number of neighbors (" + n + ") cannot"
+ " be negative or greater than number of nodes (" + m_count + ").");
}
if (key.length != m_K) {
throw new RuntimeException("KDTree: wrong key size!");
}
List<T> nbrs = new ArrayList<T>(n);
NearestNeighborList nnl = new NearestNeighborList(n);
HRect hr = HRect.infiniteHRect(key.length);
double max_dist_sqd = Double.MAX_VALUE;
HPoint keyp = new HPoint(key);
KDNode.nnbr(m_root, keyp, hr, max_dist_sqd, 0, m_K, nnl);
for (int i = 0; i < n; ++i) {
KDNode<T> kd = (KDNode) nnl.removeHighest();
nbrs.add(kd.v);
}
return nbrs;
}
private double mandist(double[] p1,double[] p2){
double dist = 0.0;
for(int i=0;i<p1.length;i++){
dist += Math.abs(p1[i]-p2[i]);
}
return dist;
}
//find the n nearest nodes and return the distances to them
public List<Double> nearestDistance(double[] key, int n) {
if (n < 0 || n > m_count) {
throw new IllegalArgumentException("Number of neighbors (" + n + ") cannot"
+ " be negative or greater than number of nodes (" + m_count + ").");
}
if (key.length != m_K) {
throw new RuntimeException("KDTree: wrong key size!");
}
List<Double> nbrs = new ArrayList<Double>(n);
NearestNeighborList nnl = new NearestNeighborList(n);
HRect hr = HRect.infiniteHRect(key.length);
double max_dist_sqd = Double.MAX_VALUE;
HPoint keyp = new HPoint(key);
KDNode.nnbr(m_root, keyp, hr, max_dist_sqd, 0, m_K, nnl);
for (int i = 0; i < n; ++i) {
KDNode<T> kd = (KDNode) nnl.removeHighest();
nbrs.add(mandist(kd.k.coord,key));
}
return nbrs;
}
public String toString() {
return m_root.toString(0);
}
}
NearestNeighborList.java
/**
 * Nearest-neighbor list backed by a bounded priority queue; it keeps only the
 * capacity smallest distances (a short usage demo follows the class).
 */
public class NearestNeighborList implements Serializable {
public static int REMOVE_HIGHEST = 1;
public static int REMOVE_LOWEST = 2;
PriorityQueue m_Queue = null;
int m_Capacity = 0;
//keep only the nearest capacity neighbors
public NearestNeighborList(int capacity) {
m_Capacity = capacity;
m_Queue = new PriorityQueue(m_Capacity, Double.POSITIVE_INFINITY);
}
public double getMaxPriority() {
if (m_Queue.length() == 0) {
return Double.POSITIVE_INFINITY;
}
return m_Queue.getMaxPriority();
}
public boolean insert(Object object, double priority) {
if (m_Queue.length() < m_Capacity) {
//not yet at capacity: add directly to the queue
m_Queue.add(object, priority);
return true;
}
if (priority > m_Queue.getMaxPriority()) {
//farther than everything already in the full queue: do not insert
return false;
}
//remove the element with the largest priority, i.e. the current farthest neighbor
m_Queue.remove();
//insert the new element
m_Queue.add(object, priority);
return true;
}
public boolean isCapacityReached() {
return m_Queue.length() >= m_Capacity;
}
public Object getHighest() {
return m_Queue.front();
}
public boolean isEmpty() {
return m_Queue.length() == 0;
}
public int getSize() {
return m_Queue.length();
}
public Object removeHighest() {
return m_Queue.remove();
}
}
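A short hypothetical demo of how this list behaves (the class NNLDemo and the numbers are made up for illustration): with capacity 3 it keeps only the three smallest priorities inserted so far, and getMaxPriority() is then the distance to the current 3rd-nearest neighbor.
public class NNLDemo {
    public static void main(String[] args) {
        //keep only the 3 smallest distances seen so far
        NearestNeighborList nnl = new NearestNeighborList(3);
        nnl.insert("a", 5.0);
        nnl.insert("b", 1.0);
        nnl.insert("c", 3.0);
        nnl.insert("d", 2.0);                     //evicts "a" (5.0), the current farthest
        System.out.println(nnl.getMaxPriority()); //3.0: the 3rd-smallest distance so far
        System.out.println(nnl.getSize());        //3
    }
}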
PriorityQueue.java
/**
 * Priority queue implemented as a binary max-heap: the element with the largest
 * priority sits at the front and is what front()/remove() return.
 */
public class PriorityQueue implements Serializable {
private double maxPriority = Double.MAX_VALUE;
private Object[] data;
private double[] value;
private int count;
private int capacity;
public PriorityQueue() {
init(20);
}
public PriorityQueue(int capacity) {
init(capacity);
}
public PriorityQueue(int capacity, double maxPriority) {
this.maxPriority = maxPriority;
init(capacity);
}
private void init(int size) {
capacity = size;
data = new Object[capacity + 1];
value = new double[capacity + 1];
value[0] = maxPriority;
data[0] = null;
}
public void add(Object element, double priority) {
if (count++ >= capacity) {
expandCapacity();
}
value[count] = priority;
data[count] = element;
bubbleUp(count);
}
public Object remove() {
if (count == 0) {
return null;
}
Object element = data[1];
data[1] = data[count];
value[1] = value[count];
data[count] = null;
value[count] = 0L;
count--;
bubbleDown(1);
return element;
}
public Object front() {
return data[1];
}
public double getMaxPriority() {
return value[1];
}
private void bubbleDown(int pos) {
Object element = data[pos];
double priority = value[pos];
int child;
for (; pos * 2 <= count; pos = child) {
child = pos * 2;
if (child != count) {
if (value[child] < value[child + 1]) {
child++;
}
}
if (priority < value[child]) {
value[pos] = value[child];
data[pos] = data[child];
} else {
break;
}
}
value[pos] = priority;
data[pos] = element;
}
private void bubbleUp(int pos) {
Object element = data[pos];
double priority = value[pos];
while (value[pos / 2] < priority) {
value[pos] = value[pos / 2];
data[pos] = data[pos / 2];
pos /= 2;
}
value[pos] = priority;
data[pos] = element;
}
private void expandCapacity() {
capacity = count * 2;
Object[] elements = new Object[capacity + 1];
double[] prioritys = new double[capacity + 1];
System.arraycopy(data, 0, elements, 0, data.length);
System.arraycopy(value, 0, prioritys, 0, data.length);
data = elements;
value = prioritys;
}
public void clear() {
//elements occupy slots 1..count
for (int i = 1; i <= count; i++) {
data[i] = null;
}
count = 0;
}
public int length() {
return count;
}
}
Let's write a quick program to check the performance. The test code is as follows:
/**
 * k-NN search benchmark: k-d tree vs. brute force
 */
public class KDTreeTest {
public double mandist(Double[] p1,Double[] p2){
double dist = 0.0;
for(int i=0;i<p1.length;i++){
dist += Math.abs(p1[i]-p2[i]);
}
return dist;
}
public void test(){
List<Double> list1 = new ArrayList<Double>();
List<Double> list2 = new ArrayList<Double>();
int k=100;
System.out.println("input size:"+k);
int dimension = 4;
//generate k random 4-dimensional points
List<Double[]> list = new ArrayList<Double[]>();
for(int i=0;i<k;i++){
Double[] arr = new Double[dimension];
for(int j=0;j<4;j++){
arr[j] = Math.random();
}
list.add(arr);
}
System.out.println("****************kdtree********************");
long time1 = System.currentTimeMillis();
//k-d tree: build the tree, then find each point's distance to its nearest-th nearest neighbor (nearest = 8 below)
KDTree<Integer> kdTree = new KDTree<Integer>(dimension);
for(int i=0;i<k;i++){
double[] curDouble = new double[dimension];
int index = -1;
for(Double item:list.get(i)){
curDouble[++index] = item;
}
kdTree.insert(curDouble,i);
}
int nearest = 8;
for(int i=0;i<k;i++){
double[] curDouble = new double[dimension];
int index = -1;
for(Double item:list.get(i)){
curDouble[++index] = item;
}
List<Double> distance = kdTree.nearestDistance(curDouble, nearest + 1);
Collections.sort(distance);
//System.out.println(distance.get(distance.size() - 1));
list1.add(distance.get(distance.size()-1));
}
long time2 = System.currentTimeMillis();
System.out.println((time2-time1)+"ms");
System.out.println("****************kdtree********************");
System.out.println("****************normal********************");
long time3 = System.currentTimeMillis();
//brute force: for each point, sort all pairwise distances and take the nearest-th smallest (excluding the point itself)
for(int i=0;i<k;i++){
List<Double> distance = new ArrayList<Double>();
for(int j=0;j<k;j++){
distance.add(mandist(list.get(i), list.get(j)));
}
Collections.sort(distance);
//System.out.println(distance.get(nearest));
list2.add(distance.get(nearest));
}
long time4 = System.currentTimeMillis();
System.out.println((time4-time3)+"ms");
System.out.println("****************normal********************");
boolean same = true;
for(int i=0;i<list1.size();i++){
if(!list1.get(i).equals(list2.get(i))){
same = false;
break;
}
}
if(same){
System.out.println("result is same");
}else{
System.out.println("result not same");
}
}
public static void main(String[] args){
new KDTreeTest().test();
}
}
A comparison at various data sizes (k-d tree vs. brute force):
At the 100,000-point level, the brute-force method could not produce a result even after several minutes.
At the 600,000-point level, looking at the k-d tree alone:
So how do we use this in a Spark program? Like this:
double eps = Double.MAX_VALUE;
final KDTree<Integer> kdtree = new KDTree<Integer>(4);
scala.reflect.ClassTag<KDTree<Integer>> curClassTag = scala.reflect.ClassTag$.MODULE$.apply(KDTree.class);
try{
JavaRDD<Row> rowRdd = df.select("fwd_ppf", "fwd_bpp", "recv_ppf", "recv_bpp").toJavaRDD();
List<Row> list = rowRdd.collect();
for(int i=0;i<list.size();i++){
kdtree.insert(new double[]{
Double.valueOf(String.valueOf(list.get(i).get(0))),
Double.valueOf(String.valueOf(list.get(i).get(1))),
Double.valueOf(String.valueOf(list.get(i).get(2))),
Double.valueOf(String.valueOf(list.get(i).get(3)))
},i);
}
//create a broadcast variable holding the k-d tree
final Broadcast<KDTree<Integer>> broadCast = sqlContext.sparkContext().broadcast(kdtree,curClassTag);
JavaRDD<Double> sortRDD = rowRdd.map(new Function<Row, Double>() {
public Double call(Row row) throws Exception {
double distance = Double.MAX_VALUE;
try {
//find the distance to the minPts-th nearest node
List<Double> list = broadCast.getValue().nearestDistance(new double[]{
Double.valueOf(String.valueOf(row.get(0))),
Double.valueOf(String.valueOf(row.get(1))),
Double.valueOf(String.valueOf(row.get(2))),
Double.valueOf(String.valueOf(row.get(3)))
}, minPts + 1);
Collections.sort(list);
distance = list.get(list.size() - 1);
} catch (Exception ex) {
logger.error("", ex);
}
return distance;
}
});
sortRDD.persist(StorageLevel.MEMORY_AND_DISK());
The Spark run results are as follows:
The current drawback is that building the tree requires collecting the data to the driver. Let's estimate how much data the driver has to hold:
with 4 doubles per point, at 60 million points the raw values alone take 60,000,000 × 4 × 8 B ≈ 1.92 GB (JVM object overhead for the tree nodes adds more on top of that).
So the driver memory should be configured on the generous side; I think 4G is a fairly safe floor.