Distance from each of n points in m-dimensional space to its k-th nearest neighbor

The problem is shown in the figure.

m = 4, i.e., every point lives in four-dimensional space.

n is not fixed: it can be tens of thousands, hundreds of thousands, or even tens of millions.

The computation runs on Spark with the following resources: executor-cores: 6, executor-memory: 10G.

 

Solution 1:

First, turn the point matrix into a DataFrame (each Row has the form [uuid, double1, double2, double3, double4]).

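Next, join the DataFrame with itself, then map over the joined rows to compute the distance between each point and every other point, returning a JavaPairRDD whose key is the point's uuid and whose value is the Euclidean distance.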

Then call groupByKey. In the resulting JavaPairRDD, the key is the point's uuid and the value is an Iterable holding the distances from that point to all other points.

Finally, map again: sort each point's distance Iterable to obtain the distance from that point to its k-th nearest neighbor.
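The steps above amount to a brute-force all-pairs computation. As a minimal single-machine sketch of the same logic (not the author's Spark job), the Java below plays the role of the self-join with a double loop, of groupByKey with a map from id to distance list, and of the final map with a sort. The class name `BruteForceKth` is hypothetical, self-pairs are skipped, and k is assumed to satisfy 1 <= k <= n - 1.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class BruteForceKth {
    // Euclidean distance between two points of the same dimension
    static double euclidean(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    // For each point id, the distance to its k-th nearest other point (1 <= k <= n - 1).
    // Every pair is evaluated, so the cost is O(n^2) distance computations.
    static Map<String, Double> kthNearest(Map<String, double[]> points, int k) {
        // "join" + "groupByKey": id -> distances to all other points
        Map<String, List<Double>> grouped = new HashMap<>();
        for (Map.Entry<String, double[]> p : points.entrySet()) {
            for (Map.Entry<String, double[]> q : points.entrySet()) {
                if (p.getKey().equals(q.getKey())) {
                    continue; // skip the self-pair
                }
                grouped.computeIfAbsent(p.getKey(), id -> new ArrayList<>())
                       .add(euclidean(p.getValue(), q.getValue()));
            }
        }
        // final "map": sort each list and pick the k-th smallest
        Map<String, Double> result = new HashMap<>();
        for (Map.Entry<String, List<Double>> e : grouped.entrySet()) {
            List<Double> dists = e.getValue();
            Collections.sort(dists);
            result.put(e.getKey(), dists.get(k - 1));
        }
        return result;
    }

    public static void main(String[] args) {
        Map<String, double[]> pts = new HashMap<>();
        pts.put("a", new double[]{0, 0, 0, 0});
        pts.put("b", new double[]{1, 0, 0, 0});
        pts.put("c", new double[]{3, 0, 0, 0});
        pts.put("d", new double[]{0, 4, 0, 0});
        Map<String, Double> second = kthNearest(pts, 2);
        System.out.println("a -> " + second.get("a")); // a -> 3.0
    }
}
```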

The code is as follows:


Performance: with n = 10000, it takes roughly 3 minutes.

Bottleneck analysis:

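1. The computational complexity is O(n^2); for n = 10000 that is already on the order of 10^8 distance computations.

2. The uuid is a string, which increases the amount of data moved during the distributed shuffle.

3. Euclidean distance requires squaring and a square root, and this floating-point arithmetic is relatively expensive.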
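On point 3: when distances are only compared, the square root can be skipped, because sqrt is monotonically increasing on non-negative numbers, so ranking by squared Euclidean distance gives the same order as ranking by Euclidean distance. A minimal sketch, using a hypothetical helper class `SquaredDistanceRanking`, ranks by squared distance and takes a single square root for the k-th value:

```java
import java.util.Arrays;

class SquaredDistanceRanking {
    // Squared Euclidean distance: no Math.sqrt call per pair
    static double sqDist(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            s += d * d;
        }
        return s;
    }

    // Euclidean distance from query to its k-th nearest point in others (1-based k).
    // Sorting squared distances gives the same order as sorting true distances,
    // so the square root is taken only once, on the selected value.
    static double kthNearestEuclidean(double[] query, double[][] others, int k) {
        double[] sq = new double[others.length];
        for (int i = 0; i < others.length; i++) {
            sq[i] = sqDist(query, others[i]);
        }
        Arrays.sort(sq);
        return Math.sqrt(sq[k - 1]);
    }

    public static void main(String[] args) {
        double[] origin = {0, 0, 0, 0};
        double[][] others = {{3, 4, 0, 0}, {1, 0, 0, 0}, {0, 0, 2, 0}};
        System.out.println(kthNearestEuclidean(origin, others, 2)); // 2.0
    }
}
```

Switching to Manhattan distance, as Solution 2 does, is a different trade-off: it changes the metric, so the k-th nearest neighbor can differ. For example, from the origin, (3, 3, 0, 0) is nearer than (5, 0, 0, 0) under Euclidean distance (about 4.24 vs. 5) but farther under Manhattan distance (6 vs. 5).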


Solution 2:

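1. Replace the Euclidean distance with the Manhattan distance.

2. Generate the id with zipWithUniqueId, so it is a long instead of a string.

3. Instead of join, generate the point pairs with the RDD cartesian method.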

Performance: not much different from Solution 1.



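In summary, the dominant cost is computing the distance from every point to every other point (the join in Solution 1 and the cartesian in Solution 2 are both very time-consuming).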
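A quick count shows why the two solutions scale the same way: a self-join or cartesian product over n points produces n^2 pairs (n(n - 1) once self-pairs are dropped), regardless of how ids are encoded or which distance is used:

```latex
n^2:\qquad n = 10^4 \Rightarrow 10^{8},\qquad n = 10^5 \Rightarrow 10^{10},\qquad n = 6\times 10^{7} \Rightarrow 3.6\times 10^{15}
```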

To solve the problem fundamentally, the step "compute the distance from each point to every other point" must be eliminated.


A k-d tree is a binary tree that splits subtrees on the dimension curDimension % DimensionNum, cycling through the dimensions level by level.

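To search for the k nearest neighbors, you only need to maintain a priority queue while descending the tree as in a binary search, skipping any subtree whose region is already farther away than the current k-th best candidate. This eliminates the step of computing the distance from each point to every other point.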
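The idea can be illustrated with a compact sketch that is independent of the classes below (the names `KdKnnSketch`, `insert`, `search` and `kthNearest` are hypothetical). It splits on `depth % dims`, sends ties to the left as `KDNode.ins` does, keeps the k best Manhattan distances in a bounded max-heap (`java.util.PriorityQueue` with a reversed comparator), and prunes the far subtree with a simpler bound than the full code: every point on the far side of the splitting plane is at least |q[s] - pivot[s]| away. `main` checks the result against brute force on random points.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.PriorityQueue;
import java.util.Random;

class KdKnnSketch {
    // A k-d tree node: one point plus its two subtrees
    static final class Node {
        final double[] p;
        Node left, right;
        Node(double[] p) { this.p = p; }
    }

    // The split dimension cycles with depth (depth % dims); ties go left, larger values go right
    static Node insert(Node t, double[] p, int depth) {
        if (t == null) return new Node(p);
        int s = depth % p.length;
        if (p[s] > t.p[s]) t.right = insert(t.right, p, depth + 1);
        else t.left = insert(t.left, p, depth + 1);
        return t;
    }

    static double manhattan(double[] a, double[] b) {
        double d = 0;
        for (int i = 0; i < a.length; i++) d += Math.abs(a[i] - b[i]);
        return d;
    }

    // heap is a max-heap of the k smallest distances seen so far; its head is the current k-th best
    static void search(Node t, double[] q, int k, int depth, PriorityQueue<Double> heap) {
        if (t == null) return;
        double d = manhattan(q, t.p);
        if (heap.size() < k) heap.add(d);
        else if (d < heap.peek()) { heap.poll(); heap.add(d); }
        int s = depth % q.length;
        double diff = q[s] - t.p[s];
        Node near = diff > 0 ? t.right : t.left;
        Node far = diff > 0 ? t.left : t.right;
        search(near, q, k, depth + 1, heap);
        // Every far-side point differs from q by at least |diff| in dimension s, hence is at least
        // |diff| away in Manhattan distance: skip the far side if that cannot beat the k-th best
        if (heap.size() < k || Math.abs(diff) < heap.peek()) {
            search(far, q, k, depth + 1, heap);
        }
    }

    // Distance to the k-th nearest tree point (the query itself counts if it is in the tree)
    static double kthNearest(Node root, double[] q, int k) {
        PriorityQueue<Double> heap = new PriorityQueue<>(Collections.reverseOrder());
        search(root, q, k, 0, heap);
        return heap.peek();
    }

    public static void main(String[] args) {
        Random rnd = new Random(42);
        double[][] pts = new double[2000][4];
        Node root = null;
        for (double[] p : pts) {
            for (int j = 0; j < 4; j++) p[j] = rnd.nextDouble();
            root = insert(root, p, 0);
        }
        int k = 8;
        boolean same = true;
        for (double[] q : pts) {
            double[] all = new double[pts.length];
            for (int i = 0; i < pts.length; i++) all[i] = manhattan(q, pts[i]);
            Arrays.sort(all);
            // k + 1 because the query point itself (distance 0) is in the tree
            if (kthNearest(root, q, k + 1) != all[k]) same = false;
        }
        System.out.println(same ? "result is same" : "result not same");
    }
}
```

The plane-distance bound prunes less than the rectangle-based bound used in `KDNode.nnbr`, but it is valid for both Manhattan and Euclidean distance.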

The code is as follows:

HPoint.java

import java.io.Serializable;

public class HPoint implements Serializable {
    protected double[] coord;

    protected HPoint(int n) {
        coord = new double[n];
    }

    protected HPoint(double[] x) {
        coord = new double[x.length];
        for (int i = 0; i < x.length; ++i) {
            coord[i] = x[i];
        }
    }

    protected Object clone() {
        return new HPoint(coord);
    }

    protected boolean equals(HPoint p) {
        for (int i = 0; i < coord.length; ++i) {
            if (coord[i] != p.coord[i]) {
                return false;
            }
        }
        return true;
    }

    // Manhattan distance
    protected static double manhanttandist(HPoint x,HPoint y){
        double dist = 0;
        for(int i=0; i < x.coord.length; ++i){
            dist += Math.abs(x.coord[i] - y.coord[i]);
        }
        return dist;
    }

    // Squared Euclidean distance
    protected static double sqrdist(HPoint x, HPoint y) {
        double dist = 0;
        for (int i = 0; i < x.coord.length; ++i) {
            double diff = (x.coord[i] - y.coord[i]);
            dist += (diff * diff);
        }
        return dist;
    }

    // Euclidean distance
    protected static double eucdist(HPoint x, HPoint y) {
        return Math.sqrt(sqrdist(x, y));
    }

    public String toString() {
        String s = "";
        for (int i = 0; i < coord.length; ++i) {
            s = s + coord[i] + " ";
        }
        return s;
    }
}

HRect.java

import java.io.Serializable;

public class HRect implements Serializable {
    protected HPoint min;
    protected HPoint max;

    protected HRect(int ndims) {
        min = new HPoint(ndims);
        max = new HPoint(ndims);
    }

    protected HRect(HPoint vmin, HPoint vmax) {
        min = (HPoint) vmin.clone();
        max = (HPoint) vmax.clone();
    }

    protected Object clone() {
        return new HRect(min, max);
    }

    // Return the point inside this rectangle that is closest to t
    protected HPoint closest(HPoint t) {
        HPoint p = new HPoint(t.coord.length);

        for (int i = 0; i < t.coord.length; ++i) {
            if (t.coord[i] <= min.coord[i]) {
                p.coord[i] = min.coord[i];
            } else if (t.coord[i] >= max.coord[i]) {
                p.coord[i] = max.coord[i];
            } else {
                p.coord[i] = t.coord[i];
            }
        }

        return p;
    }

    // Create an unbounded (infinite) rectangle in d dimensions
    protected static HRect infiniteHRect(int d) {
        HPoint vmin = new HPoint(d);
        HPoint vmax = new HPoint(d);

        for (int i = 0; i < d; ++i) {
            vmin.coord[i] = Double.NEGATIVE_INFINITY;
            vmax.coord[i] = Double.POSITIVE_INFINITY;
        }

        return new HRect(vmin, vmax);
    }

    public String toString() {
        return min + "\n" + max + "\n";
    }
}

KDNode.java

import java.io.Serializable;
import java.util.Vector;

// k-d tree node
public class KDNode<T> implements Serializable {
    protected HPoint k;
    protected KDNode left, right;
    protected boolean deleted;
    T v;

    private KDNode(HPoint key, T val) {
        k = key;
        v = val;
        left = null;
        right = null;
        deleted = false;
    }

    // Insert a node
    protected static KDNode ins(HPoint key, Object val, KDNode t, int lev, int K) {
        if (t == null) {
            t = new KDNode(key, val);
        } else if (key.equals(t.k)) {
            // The key already exists; if that node was marked as deleted, restore it
            if (t.deleted) {
                t.deleted = false;
                t.v = val;
            }
        } else if (key.coord[lev] > t.k.coord[lev]) {
            t.right = ins(key, val, t.right, (lev + 1) % K, K);
        } else {
            t.left = ins(key, val, t.left, (lev + 1) % K, K);
        }

        return t;
    }

    // Search for a node
    protected static KDNode srch(HPoint key, KDNode t, int K) {
        for (int lev = 0; t != null; lev = (lev + 1) % K) {
            if (!t.deleted && key.equals(t.k)) {
                return t;
            } else if (key.coord[lev] > t.k.coord[lev]) {
                t = t.right;
            } else {
                t = t.left;
            }
        }

        return null;
    }

    // Range search: add to v every node whose key lies inside the box [lowk, uppk]
    protected static void rsearch(HPoint lowk, HPoint uppk, KDNode t, int lev, int K, Vector<KDNode> v) {
        if (t == null) {
            return;
        }
        if (lowk.coord[lev] <= t.k.coord[lev]) {
            rsearch(lowk, uppk, t.left, (lev + 1) % K, K, v);
        }
        int j;
        for (j = 0; j < K && lowk.coord[j] <= t.k.coord[j] && uppk.coord[j] >= t.k.coord[j]; j++)
            ;
        if (j == K) {
            v.add(t);
        }
        if (uppk.coord[lev] > t.k.coord[lev]) {
            rsearch(lowk, uppk, t.right, (lev + 1) % K, K, v);
        }
    }

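    // k-nearest-neighbor search; max_dist_sqd is the current pruning bound (a Manhattan distance here, despite the name)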
    protected static void nnbr(KDNode kd, HPoint target, HRect hr, double max_dist_sqd, int lev, int K,
                               NearestNeighborList nnl) {
        if (kd == null) {
            return;
        }
        int s = lev % K;

        HPoint pivot = kd.k;
        double pivot_to_target = HPoint.manhanttandist(pivot, target);

        HRect left_hr = hr;
        HRect right_hr = (HRect) hr.clone();
        left_hr.max.coord[s] = pivot.coord[s];
        right_hr.min.coord[s] = pivot.coord[s];

        boolean target_in_left = target.coord[s] < pivot.coord[s];

        KDNode nearer_kd;
        HRect nearer_hr;
        KDNode further_kd;
        HRect further_hr;

        if (target_in_left) {
            nearer_kd = kd.left;
            nearer_hr = left_hr;
            further_kd = kd.right;
            further_hr = right_hr;
        }
        else {
            nearer_kd = kd.right;
            nearer_hr = right_hr;
            further_kd = kd.left;
            further_hr = left_hr;
        }

        nnbr(nearer_kd, target, nearer_hr, max_dist_sqd, lev + 1, K, nnl);

        KDNode nearest = (KDNode) nnl.getHighest();
        double dist_sqd;

        if (!nnl.isCapacityReached()) {
            dist_sqd = Double.MAX_VALUE;
        } else {
            dist_sqd = nnl.getMaxPriority();
        }

        max_dist_sqd = Math.min(max_dist_sqd, dist_sqd);

        HPoint closest = further_hr.closest(target);
        if (Double.valueOf(HPoint.manhanttandist(closest, target)).compareTo(max_dist_sqd) < 0) {
            if (pivot_to_target < dist_sqd) {
                nearest = kd;
                dist_sqd = pivot_to_target;
                if (!kd.deleted) {
                    nnl.insert(kd, dist_sqd);
                }
                if (nnl.isCapacityReached()) {
                    max_dist_sqd = nnl.getMaxPriority();
                } else {
                    max_dist_sqd = Double.MAX_VALUE;
                }
            }
            nnbr(further_kd, target, further_hr, max_dist_sqd, lev + 1, K, nnl);
            KDNode temp_nearest = (KDNode) nnl.getHighest();
            double temp_dist_sqd = nnl.getMaxPriority();
            if (temp_dist_sqd < dist_sqd) {
                nearest = temp_nearest;
                dist_sqd = temp_dist_sqd;
            }
        }
        else if (pivot_to_target < max_dist_sqd) {
            nearest = kd;
            dist_sqd = pivot_to_target;
        }
    }

    private static String pad(int n) {
        String s = "";
        for (int i = 0; i < n; ++i) {
            s += " ";
        }
        return s;
    }

    private static void hrcopy(HRect hr_src, HRect hr_dst) {
        hpcopy(hr_src.min, hr_dst.min);
        hpcopy(hr_src.max, hr_dst.max);
    }

    private static void hpcopy(HPoint hp_src, HPoint hp_dst) {
        for (int i = 0; i < hp_dst.coord.length; ++i) {
            hp_dst.coord[i] = hp_src.coord[i];
        }
    }

    protected String toString(int depth) {
        String s = k + "  " + v + (deleted ? "*" : "");
        if (left != null) {
            s = s + "\n" + pad(depth) + "L " + left.toString(depth + 1);
        }
        if (right != null) {
            s = s + "\n" + pad(depth) + "R " + right.toString(depth + 1);
        }
        return s;
    }
}

KDTree.java

import java.util.ArrayList;
import java.util.List;

/**
 * k-d tree
 */
public class KDTree<T> implements java.io.Serializable {
    // Number of dimensions
    private int m_K;

    // Root node
    private KDNode m_root;

    // Number of nodes in the tree
    private int m_count;

    // Create a k-dimensional k-d tree
    public KDTree(int k) {
        m_K = k;
        m_root = null;
    }

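    // Insert a node into the k-d tree
    // key: the k-dimensional coordinates
    // value: the node's label
    public void insert(double[] key, T value) {
        if (key.length != m_K) {
            throw new RuntimeException("KDTree: wrong key size!");
        }
        HPoint p = new HPoint(key);
        // Points with identical coordinates share one node, so only count keys that are new
        // (or that revive a deleted node); otherwise the n > m_count check in nearest()
        // would not catch requests for more neighbors than the tree actually holds
        if (KDNode.srch(p, m_root, m_K) == null) {
            m_count++;
        }
        m_root = KDNode.ins(p, value, m_root, 0, m_K);
    }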


    // Look up a k-d tree node by its key and return its value
    public Object search(double[] key) {
        if (key.length != m_K) {
            throw new RuntimeException("KDTree: wrong key size!");
        }

        KDNode kd = KDNode.srch(new HPoint(key), m_root, m_K);

        return (kd == null ? null : kd.v);
    }

    // Delete a k-d tree node (lazily: the node is only marked as deleted)
    public void delete(double[] key) {

        if (key.length != m_K) {
            throw new RuntimeException("KDTree: wrong key size!");
        } else {

            KDNode t = KDNode.srch(new HPoint(key), m_root, m_K);
            if (t == null) {
                throw new RuntimeException("KDTree: key missing!");
            } else {
                t.deleted = true;
            }

            m_count--;
        }
    }

    // Return the label of the nearest k-d tree node
    public T nearest(double[] key) {
        List<T> nbrs = nearest(key, 1);
        return nbrs.get(0);
    }

    // Return the labels of the n nearest k-d tree nodes, farthest first
    public List<T> nearest(double[] key, int n) {
        if (n < 0 || n > m_count) {
            throw new IllegalArgumentException("Number of neighbors (" + n + ") cannot"
                    + " be negative or greater than number of nodes (" + m_count + ").");
        }

        if (key.length != m_K) {
            throw new RuntimeException("KDTree: wrong key size!");
        }

        List<T> nbrs = new ArrayList<T>(n);
        NearestNeighborList nnl = new NearestNeighborList(n);

        HRect hr = HRect.infiniteHRect(key.length);
        double max_dist_sqd = Double.MAX_VALUE;
        HPoint keyp = new HPoint(key);

        KDNode.nnbr(m_root, keyp, hr, max_dist_sqd, 0, m_K, nnl);

        for (int i = 0; i < n; ++i) {
            KDNode<T> kd = (KDNode) nnl.removeHighest();
            nbrs.add(kd.v);
        }

        return nbrs;
    }

    private double mandist(double[] p1,double[] p2){
        double dist = 0.0;
        for(int i=0;i<p1.length;i++){
            dist += Math.abs(p1[i]-p2[i]);
        }
        return dist;
    }

    // Find the n nearest k-d tree nodes and return their Manhattan distances to key, farthest first
    public List<Double> nearestDistance(double[] key, int n) {
        if (n < 0 || n > m_count) {
            throw new IllegalArgumentException("Number of neighbors (" + n + ") cannot"
                    + " be negative or greater than number of nodes (" + m_count + ").");
        }

        if (key.length != m_K) {
            throw new RuntimeException("KDTree: wrong key size!");
        }

        List<Double> nbrs = new ArrayList<Double>(n);
        NearestNeighborList nnl = new NearestNeighborList(n);

        HRect hr = HRect.infiniteHRect(key.length);
        double max_dist_sqd = Double.MAX_VALUE;
        HPoint keyp = new HPoint(key);

        KDNode.nnbr(m_root, keyp, hr, max_dist_sqd, 0, m_K, nnl);

        for (int i = 0; i < n; ++i) {
            KDNode<T> kd = (KDNode) nnl.removeHighest();
            nbrs.add(mandist(kd.k.coord,key));
        }

        return nbrs;
    }

    public String toString() {
        return m_root.toString(0);
    }
}

NearestNeighborList.java

import java.io.Serializable;

/**
 * Nearest-neighbor list backed by a priority queue
 */
public class NearestNeighborList implements Serializable {
    public static int REMOVE_HIGHEST = 1;
    public static int REMOVE_LOWEST = 2;

    PriorityQueue m_Queue = null;
    int m_Capacity = 0;

    // Keep only the capacity nearest neighbors
    public NearestNeighborList(int capacity) {
        m_Capacity = capacity;
        m_Queue = new PriorityQueue(m_Capacity, Double.POSITIVE_INFINITY);
    }

    public double getMaxPriority() {
        if (m_Queue.length() == 0) {
            return Double.POSITIVE_INFINITY;
        }
        return m_Queue.getMaxPriority();
    }

    public boolean insert(Object object, double priority) {
        if (m_Queue.length() < m_Capacity) {
            // Fewer than capacity elements so far: add directly
            m_Queue.add(object, priority);
            return true;
        }
        if (priority > m_Queue.getMaxPriority()) {
            // Larger than every element already in the full queue: reject
            return false;
        }
        // Remove the element with the largest priority (the heap root, i.e. the current farthest neighbor)
        m_Queue.remove();
        // Insert the new element
        m_Queue.add(object, priority);
        return true;
    }

    public boolean isCapacityReached() {
        return m_Queue.length() >= m_Capacity;
    }

    public Object getHighest() {
        return m_Queue.front();
    }

    public boolean isEmpty() {
        return m_Queue.length() == 0;
    }

    public int getSize() {
        return m_Queue.length();
    }

    public Object removeHighest() {
        return m_Queue.remove();
    }
}

PriorityQueue.java

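import java.io.Serializable;

/**
 * Priority queue implemented as a binary max-heap: the element with the largest priority is at the front
 */
public class PriorityQueue implements Serializable {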
    private double maxPriority = Double.MAX_VALUE;
    private Object[] data;
    private double[] value;
    private int count;
    private int capacity;

    public PriorityQueue() {
        init(20);
    }

    public PriorityQueue(int capacity) {
        init(capacity);
    }

    public PriorityQueue(int capacity, double maxPriority) {
        this.maxPriority = maxPriority;
        init(capacity);
    }

    private void init(int size) {
        capacity = size;
        data = new Object[capacity + 1];
        value = new double[capacity + 1];
        value[0] = maxPriority;
        data[0] = null;
    }

    public void add(Object element, double priority) {
        if (count++ >= capacity) {
            expandCapacity();
        }
        value[count] = priority;
        data[count] = element;
        bubbleUp(count);
    }

    public Object remove() {
        if (count == 0) {
            return null;
        }
        Object element = data[1];
        data[1] = data[count];
        value[1] = value[count];
        data[count] = null;
        value[count] = 0L;
        count--;
        bubbleDown(1);
        return element;
    }

    public Object front() {
        return data[1];
    }

    public double getMaxPriority() {
        return value[1];
    }

    private void bubbleDown(int pos) {
        Object element = data[pos];
        double priority = value[pos];
        int child;
        for (; pos * 2 <= count; pos = child) {
            child = pos * 2;
            if (child != count) {
                if (value[child] < value[child + 1]) {
                    child++;
                }
            }
            if (priority < value[child]) {
                value[pos] = value[child];
                data[pos] = data[child];
            } else {
                break;
            }
        }
        value[pos] = priority;
        data[pos] = element;
    }

    private void bubbleUp(int pos) {
        Object element = data[pos];
        double priority = value[pos];
        while (value[pos / 2] < priority) {
            value[pos] = value[pos / 2];
            data[pos] = data[pos / 2];
            pos /= 2;
        }
        value[pos] = priority;
        data[pos] = element;
    }

    private void expandCapacity() {
        capacity = count * 2;
        Object[] elements = new Object[capacity + 1];
        double[] prioritys = new double[capacity + 1];
        System.arraycopy(data, 0, elements, 0, data.length);
        System.arraycopy(value, 0, prioritys, 0, data.length);
        data = elements;
        value = prioritys;
    }

    public void clear() {
        // Release references to all stored elements (slots 1..count)
        for (int i = 1; i <= count; i++) {
            data[i] = null;
        }
        count = 0;
    }

    public int length() {
        return count;
    }
}


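Note that, for performance, all distances used here are Manhattan distances.
Let's write a simple program to test the performance; the test code is as follows: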

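import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * k-nearest-neighbor search with a KD tree, compared against brute force
 */
public class KDTreeTest {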
    public double mandist(Double[] p1,Double[] p2){
        double dist = 0.0;
        for(int i=0;i<p1.length;i++){
            dist += Math.abs(p1[i]-p2[i]);
        }
        return dist;
    }

    public void test(){
        List<Double> list1 = new ArrayList<Double>();
        List<Double> list2 = new ArrayList<Double>();

        int k=100;
        System.out.println("input size:"+k);
        int dimension = 4;
        // Initialize k (here 100) random 4-dimensional points
        List<Double[]> list = new ArrayList<Double[]>();
        for(int i=0;i<k;i++){
            Double[] arr = new Double[dimension];
            for(int j=0;j<dimension;j++){
                arr[j] = Math.random();
            }
            list.add(arr);
        }

        System.out.println("****************kdtree********************");
        long time1 = System.currentTimeMillis();
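        // For each point, get the distance to its nearest-th nearest neighbor (nearest = 8);
        // ask for nearest + 1 because the point itself (distance 0) is in the tree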
        KDTree<Integer> kdTree = new KDTree<Integer>(dimension);
        for(int i=0;i<k;i++){
            double[] curDouble = new double[dimension];
            int index = -1;
            for(Double item:list.get(i)){
                curDouble[++index] = item;
            }
            kdTree.insert(curDouble,i);
        }
        int nearest = 8;
        for(int i=0;i<k;i++){
            double[] curDouble = new double[dimension];
            int index = -1;
            for(Double item:list.get(i)){
                curDouble[++index] = item;
            }
            List<Double> distance = kdTree.nearestDistance(curDouble, nearest + 1);
            Collections.sort(distance);
            //System.out.println(distance.get(distance.size() - 1));

            list1.add(distance.get(distance.size()-1));
        }
        long time2 = System.currentTimeMillis();
        System.out.println((time2-time1)+"ms");
        System.out.println("****************kdtree********************");

        System.out.println("****************normal********************");
        long time3 = System.currentTimeMillis();
        // Brute force: compute all distances, sort, and take index nearest (index 0 is the point itself)
        for(int i=0;i<k;i++){
            List<Double> distance = new ArrayList<Double>();
            for(int j=0;j<k;j++){
                distance.add(mandist(list.get(i), list.get(j)));
            }
            Collections.sort(distance);

            //System.out.println(distance.get(nearest));

            list2.add(distance.get(nearest));
        }
        long time4 = System.currentTimeMillis();
        System.out.println((time4-time3)+"ms");
        System.out.println("****************normal********************");

        boolean same = true;
        for(int i=0;i<list1.size();i++){
            if(!list1.get(i).equals(list2.get(i))){
                same = false;
                break;
            }
        }
        if(same){
            System.out.println("result is same");
        }else{
            System.out.println("result not same");
        }
    }

    public static void main(String[] args){
        new KDTreeTest().test();
    }
}

Comparison at different data sizes (kd-tree vs. brute-force search):




At the 100,000 level, the brute-force method fails to produce a result even after several minutes.


At the 600,000 level, looking at the kd-tree alone:


So how do we use it in a Spark program? The usage is as follows:

// Fragment from the driver program; df, sqlContext, minPts and logger are defined in the enclosing method
double eps = Double.MAX_VALUE;
final KDTree<Integer> kdtree = new KDTree<Integer>(4);
scala.reflect.ClassTag<KDTree<Integer>> curClassTag = scala.reflect.ClassTag$.MODULE$.apply(KDTree.class);
try {
    JavaRDD<Row> rowRdd = df.select("fwd_ppf", "fwd_bpp", "recv_ppf", "recv_bpp").toJavaRDD();
    // Collect every point to the driver and build the k-d tree there
    List<Row> list = rowRdd.collect();
    for (int i = 0; i < list.size(); i++) {
        kdtree.insert(new double[]{
                Double.valueOf(String.valueOf(list.get(i).get(0))),
                Double.valueOf(String.valueOf(list.get(i).get(1))),
                Double.valueOf(String.valueOf(list.get(i).get(2))),
                Double.valueOf(String.valueOf(list.get(i).get(3)))
        }, i);
    }

    // Broadcast the k-d tree to the executors
    final Broadcast<KDTree<Integer>> broadCast = sqlContext.sparkContext().broadcast(kdtree, curClassTag);
    JavaRDD<Double> sortRDD = rowRdd.map(new Function<Row, Double>() {
        public Double call(Row row) throws Exception {
            double distance = Double.MAX_VALUE;
            try {
                // Ask for minPts + 1 neighbors: the point itself is in the tree at distance 0,
                // so the largest returned distance is the distance to the minPts-th nearest other point
                List<Double> neighborDistances = broadCast.getValue().nearestDistance(new double[]{
                        Double.valueOf(String.valueOf(row.get(0))),
                        Double.valueOf(String.valueOf(row.get(1))),
                        Double.valueOf(String.valueOf(row.get(2))),
                        Double.valueOf(String.valueOf(row.get(3)))
                }, minPts + 1);
                Collections.sort(neighborDistances);
                distance = neighborDistances.get(neighborDistances.size() - 1);
            } catch (Exception ex) {
                logger.error("", ex);
            }
            return distance;
        }
    });

    sortRDD.persist(StorageLevel.MEMORY_AND_DISK());
    // ... subsequent processing of sortRDD ...
} catch (Exception ex) {
    logger.error("", ex);
}


The Spark run results are as follows:



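The current drawback is that building the tree requires collecting the data to the driver. Let's estimate how much data the driver has to hold: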

With 4 doubles per point and 60 million points, the number of bytes is:
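Counting only the raw coordinate payload (8 bytes per double), the arithmetic works out as follows; the collected Row objects, the HPoint and KDNode objects with their headers and references, and the boxed Integer labels all add overhead on top of this figure:

```latex
4 \times 8\ \text{B} \times 6\times 10^{7} = 1.92\times 10^{9}\ \text{B} \approx 1.92\ \text{GB} \approx 1.79\ \text{GiB}
```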


So the driver should be given generous memory; I think 4G is a fairly safe setting.
