最近学习了左神BFPRT算法,给大家先讲个段子。
左神说他每次去美国面试,他都会拿BFPRT算法吹一吹。美国5个大佬在一个美丽的地方研究出来这个算法,他说自己热爱算法,他会BFPRT,每次去美国都会怀着朝圣的姿态去在那个地方转转。面试官一听:哇,这么厉害, 过!!!!!!!
好了,下来说说这个算法。
BFPRT算法是在进行大量数据排序求topk(前k个最大或最小的数)时最优算法。
为什么叫BFPRT算法,因为是美国5个大佬搞出来的,所以分别取他们的名字命名算法。
算法步骤大概如下:
1 将数组划分区域,5个数为一组,把数组分为若干组。(为什么是5呢?因为5能更好的收敛时间复杂度。)
2 将分开的组进行组内排序(这里我用的是插入排序),找出每组数的中位数。
3 找出的若干个中位数,单独放一组进行排序。
4 在找出的中位数组中,递归调用步骤2的过程,找出中位数组的中位数,把这个中位数叫做划分数。
5 用划分数进行partition过程(类似于快速排序的每趟的划分过程):小于划分数放左边,等于划分数放中间,大于划分数放右边。
6 将划分数与k进行比较:
k<划分数,在步骤5的左边区域递归上述过程,找出左边区域的中位数与k比较,直到与k相等,最后的划分数左边就是前k小的数;
k=划分数,那么左边区域直接为前k小的数;
k>划分数,在步骤5的右边区域递归上述过程,找出右边区域的中位数与k比较,直到与k相等,最后的划分数左边就是前k小的数;
为什么要花费这么多步骤去寻找划分数呢?假设数组长度为n,那么整个数组一共有n/5个中位数,在中位数组找出中位数组的中位数(划分数),那么就会有3*(n/10)个数比划分数小-如图红色区域所示:
我们一次就可以至少刷掉3*(n/10)、最多7*(n/10)的数量级的数据进行递归排序,节省了很多时间!
代码实现如下所示:
package com.zuoshen;
/**
* BFPRT算法求topk,时间复杂度bO(n)
*
* @author wanglongfei
* E-mail: islongfei@gmail.com
* @version 2017年8月5日
*
*/
public class BFPRT {
// 得到前k个最小的数
public static int[] getMinKNumsByBFPRT(int[] arr, int k) {
if (k < 1 || k > arr.length) {
return arr;
}
int minKth = getMinKthByBFPRT(arr, k);
int[] res = new int[k];//res前k个结果集
int index = 0;
for (int i = 0; i != arr.length; i++) {
if (arr[i] < minKth) {
res[index++] = arr[i];
}
}
for (; index != res.length; index++) {
res[index] = minKth;
}
return res;
}
//找出比k小的前k个数
public static int getMinKthByBFPRT(int[] arr, int K) {
int[] copyArr = copyArray(arr);
return select(copyArr, 0, copyArr.length - 1, K - 1);
}
//复制数组
public static int[] copyArray(int[] arr) {
int[] res = new int[arr.length];
for (int i = 0; i != res.length; i++) {
res[i] = arr[i];
}
return res;
}
//用划分值与k相比,依次递归排序
public static int select(int[] arr, int begin, int end, int i) {
if (begin == end) { //begin数组的开始 end数组的结尾 i表示要求的第k个数
return arr[begin];
}
int pivot = medianOfMedians(arr, begin, end);//找出划分值(中位数组中的中位数)
int[] pivotRange = partition(arr, begin, end, pivot);
if (i >= pivotRange[0] && i <= pivotRange[1]) {//小于放左边,=放中间,大于放右边
return arr[i];
} else if (i < pivotRange[0]) {
return select(arr, begin, pivotRange[0] - 1, i);
} else {
return select(arr, pivotRange[1] + 1, end, i);
}
}
//找出中位数组中的中位数
public static int medianOfMedians(int[] arr, int begin, int end) {
int num = end - begin + 1;
int offset = num % 5 == 0 ? 0 : 1; //分组:每组5个数,不满5个单独占一组
int[] mArr = new int[num / 5 + offset]; //mArr:中位数组成的数组
for (int i = 0; i < mArr.length; i++) { //计算分开后各数组的开始位置beginI 结束位置endI
int beginI = begin + i * 5;
int endI = beginI + 4;
mArr[i] = getMedian(arr, beginI, Math.min(end, endI));//对于最后一组(不满5个数),结束位置要选择end
}
return select(mArr, 0, mArr.length - 1, mArr.length / 2);
}
//划分过程,类似于快排
public static int[] partition(int[] arr, int begin, int end, int pivotValue) {
int small = begin - 1;
int cur = begin;
int big = end + 1;
while (cur != big) {
if (arr[cur] < pivotValue) {
swap(arr, ++small, cur++);
} else if (arr[cur] > pivotValue) {
swap(arr, cur, --big);
} else {
cur++;
}
}
int[] range = new int[2];
range[0] = small + 1;//比划分值小的范围
range[1] = big - 1; //比划分值大的范围
return range;
}
//计算中位数
public static int getMedian(int[] arr, int begin, int end) {
insertionSort(arr, begin, end);//将数组中的5个数排序
int sum = end + begin;
int mid = (sum / 2) + (sum % 2);
return arr[mid];
}
//数组中5个数排序(插入排序)
public static void insertionSort(int[] arr, int begin, int end) {
for (int i = begin + 1; i != end + 1; i++) {
for (int j = i; j != begin; j--) {
if (arr[j - 1] > arr[j]) {
swap(arr, j - 1, j);
} else {
break;
}
}
}
}
//交换元素顺序
public static void swap(int[] arr, int index1, int index2) {
int tmp = arr[index1];
arr[index1] = arr[index2];
arr[index2] = tmp;
}
//打印结果
public static void printArray(int[] arr) {
for (int i = 0; i != arr.length; i++) {
System.out.print(arr[i] + " ");
}
System.out.println();
}
public static void main(String[] args) {
int[] arr = { 6, 9, 1, 3, 1, 2, 2, 5, 6, 1, 3, 5, 9, 7, 2, 5, 6, 1, 9 };
printArray(getMinKNumsByBFPRT(arr, 10));
}
}