近邻搜索之制高点树(VP-Tree)

本文介绍了制高点树(VPTree)的概念,这是一种用于近邻搜索的数据结构。VPTree不同于kd-tree,它的划分策略是基于选择一个数据点作为制高点,根据其他点与其的距离来划分数据。文章详细阐述了VPT的构造过程,并提出了查询算法,包括如何在给定查询点和半径条件下,有效地搜索目标区域。此外,还提到了简易实现代码。
部署运行你感兴趣的模型镜像

引子

近邻搜索是一种很基础的又相当重要的操作,除了信息检索以外,还被广泛用于计算机视觉、机器学习等领域,如何快速有效的做近邻查询一直是一项热门的研究。较早提出的方法多基于空间划分(Space Partition),最具有代表性的如kd-tree(kdt),球树等。本篇将介绍基于空间划分方法中的一种,制高点树(Vantage Point Tree,vpt),最初在1993年提出,比kdt稍晚,提供了一个不一样的建树思路。

VPT结构

和kdt一样,vpt也是一类二叉树,不同的是在每个节点的划分策略。略微回顾一下kdt,它在每个节点选择一个维度,根据数据点在该维度上的大小将数据均分为二。而在vpt中,首先从节点中选择一个数据点(可随机选)作为制高点(vp),然后算出其它点到vp的距离大小,最后根据该距离大小将数据点均分为二。建树算法如下:

  1. 选择某数据点v作为vp
  2. 计算其它点{Xi}到v的距离{Di}
  3. 求出{Di}中值M,小于M的数据点分给左子树,大于M的数据点分给右子树
  4. 递归地建立左子树和右子树
这里提供一个简单的例子如图,框中为平面上的点,其中红框为选中的vp,根据其它点到vp的距离进行了子树划分。

VPT查询算法


vpt查询是 准确近邻查询,较适合范围查询,可方便扩展为k近邻查询。

进行近邻查询时,假定查询点为q,当前的制高点为v,距离中值为M,则有如下策略搜索到q点距离小于r的点集:

(1)  若 dist(q,v)+r≥M,递归地搜索右子树(球外区域)

(2)  若 dist(q,v)-r≤M,递归地搜索左子树(球内区域)

为了方便写公式,用图片文字来进行证明,其实就是简单的三角形不等式的应用。

简易实现代码


最后上点干货,一个简易c++实现如下:
#ifndef _VPTREE_HEADER_
#define _VPTREE_HEADER_

#include <stdlib.h>
#include <algorithm>
#include <vector>
#include <stdio.h>
#include <queue>
#include <limits>
//#include "fnn.h"

template<typename T, double (*distance)( const T&, const T& ), int (*getId)(const T&)>
class VpTree
{
public:
    VpTree() : _root(0) {}

    ~VpTree() {
        delete _root;
    }

    void create( const std::vector<T>& items ) {
        delete _root;
        _items = items;
        _root = buildFromPoints(0, items.size());
    }

    void search( const T& target, int k, std::vector<T>* results, 
        std::vector<double>* distances) 
    {
        std::priority_queue<HeapItem> heap;

        _tau = std::numeric_limits<double>::max();
        search( _root, target, k, heap );

        results->clear(); distances->clear();

        while( !heap.empty() ) {
            results->push_back( _items[heap.top().index] );
            distances->push_back( heap.top().dist );
            heap.pop();
        }

        std::reverse( results->begin(), results->end() );
        std::reverse( distances->begin(), distances->end() );
		printf("vp search dist = %f\n",distances->at(0));
		brute(target);
    }

	void search(const T& target,std::vector<T>* results,std::vector<double>* distances){
        int idx;
		double min = 1.0e+10;
		for(int i=0;i<_items.size();i++){
			double dist = distance( _items[i], target );
			if(dist<min){
				min=dist;
				idx = i;
			}
		}
		results->push_back(_items[idx]);
		distances->push_back(min);
	}

	int range_search(const T& target, double range, int *list, int &listnum){
		int hit = 0;
		for(int i=0;i<_items.size();i++){
			double dist = distance( _items[i], target );
			//debug here
			/*if(getId(_items[i])==4){
				printf("vp dist=%f range=%f\n",dist,range);
			}*/
			//-debug
			if(dist<=range){  //inside, need to check
				//list[listnum++] = getId(_items[i]);
				int id = getId(_items[i]);
				list[id] = 1;
				listnum++;
				hit++;
			}
		}
		/*_tau = range;
		rsearch( _root, target, hit, list, listnum);*/
		return hit;
	}

private:
    std::vector<T> _items;
    double _tau;

    struct Node 
    {
        int index;
        double threshold;
        Node* left;
        Node* right;

        Node() :
            index(0), threshold(0.), left(0), right(0) {}

        ~Node() {
            delete left;
            delete right;
        }
    }* _root;

    struct HeapItem {
        HeapItem( int index, double dist) :
            index(index), dist(dist) {}
        int index;
        double dist;
        bool operator<( const HeapItem& o ) const {
            return dist < o.dist;   
        }
    };

    struct DistanceComparator
    {
        const T& item;
        DistanceComparator( const T& item ) : item(item) {}
        bool operator()(const T& a, const T& b) {
            return distance( item, a ) < distance( item, b );
        }
    };

    Node* buildFromPoints( int lower, int upper )
    {
        if ( upper == lower ) {
            return NULL;
        }

        Node* node = new Node();
        node->index = lower;

        if ( upper - lower > 1 ) {

            // choose an arbitrary point and move it to the start
            int i = (int)((double)rand() / RAND_MAX * (upper - lower - 1) ) + lower;
            std::swap( _items[lower], _items[i] );

            int median = ( upper + lower ) / 2;

            // partitian around the median distance
            std::nth_element( 
                _items.begin() + lower + 1, 
                _items.begin() + median,
                _items.begin() + upper,
                DistanceComparator( _items[lower] ));

            // what was the median?
            node->threshold = distance( _items[lower], _items[median] );

            node->index = lower;
            node->left = buildFromPoints( lower + 1, median );
            node->right = buildFromPoints( median, upper );
        }

        return node;
    }

	double brute(const T& target){
		double min = 1.0e+10;
		for(int i=0;i<_items.size();i++){
			double dist = distance( _items[i], target );
			if(dist<min)
				min=dist;
		}
		return min;
		//printf("vp brute dist = %f\n",min);
	}

	void rsearch(Node* node, const T& target, int & counter, int *list, int &listnum){
		if ( node == NULL ) return;
		double dist = distance( _items[node->index], target );
		if ( dist < _tau ) {
			counter++;
			//list[ listnum++ ] = getId(_items[node->index]);
			list[getId(_items[node->index])] = 1;
		}

		if ( node->left == NULL && node->right == NULL ) {
            return;
        }

        if ( dist < node->threshold ) {
            if ( dist - _tau <= node->threshold ) {
                rsearch( node->left, target, counter, list, listnum);
            }

            if ( dist + _tau >= node->threshold ) {
                rsearch( node->right, target, counter, list, listnum );
            }

        } 
		else {
            if ( dist + _tau >= node->threshold ) {
                rsearch( node->right, target, counter, list, listnum );
            }

            if ( dist - _tau <= node->threshold ) {
                rsearch( node->left, target, counter, list, listnum);
            }
        }
	}

    void search( Node* node, const T& target, int k,
                 std::priority_queue<HeapItem>& heap )
    {
        if ( node == NULL ) return;

        double dist = distance( _items[node->index], target );
        //printf("dist=%g tau=%gn", dist, _tau );

        if ( dist < _tau ) {
            if ( heap.size() == k ) heap.pop();
            heap.push( HeapItem(node->index, dist) );
            if ( heap.size() == k ) _tau = heap.top().dist;
        }

        if ( node->left == NULL && node->right == NULL ) {
            return;
        }

        if ( dist < node->threshold ) {
            if ( dist - _tau <= node->threshold ) {
                search( node->left, target, k, heap );
            }

            if ( dist + _tau >= node->threshold ) {
                search( node->right, target, k, heap );
            }

        } else {
            if ( dist + _tau >= node->threshold ) {
                search( node->right, target, k, heap );
            }

            if ( dist - _tau <= node->threshold ) {
                search( node->left, target, k, heap );
            }
        }
    }
};

#endif

您可能感兴趣的与本文相关的镜像

Stable-Diffusion-3.5

Stable-Diffusion-3.5

图片生成
Stable-Diffusion

Stable Diffusion 3.5 (SD 3.5) 是由 Stability AI 推出的新一代文本到图像生成模型,相比 3.0 版本,它提升了图像质量、运行速度和硬件效率

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值