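The implementation details are copied almost verbatim from Mark Allen Weiss's book Data Structures and Algorithm Analysis in Java 😂. I highly recommend it for algorithm analysis.
The utility and test classes I wrote myself are pretty handy, though 😊. I also wrote a multithreaded sort (the worker threads use quicksort), but in my tests it was slower than plain quicksort 😭, so it still needs optimizing 🛠🛠.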
Main content
Source code
Sort
package com.onemsg.algorithm.sort;
/**
* Sort: a functional interface added to make testing easier
*/
@FunctionalInterface
public interface Sort {
long[] sort(long[] array);
}
Sorts
package com.onemsg.algorithm.sort;
/**
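* Sorts: picks an algorithm by array size (insertion sort below 100 elements, Shell sort below 10,000, quicksort otherwise)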
*/
public interface Sorts {
static long[] sort(long[] array){
if(array.length < 100 ){
return InsertionSort.sort(array);
}else if(array.length < 10000){
return ShellSort.sort(array);
}else{
return QuickSort.sort(array);
}
}
}
InsertionSort
package com.onemsg.algorithm.sort;
import java.lang.reflect.Array;
/**
* InsertionSort: insertion sort
*/
public class InsertionSort{
private InsertionSort(){
}
public static long[] sort(long[] array){
return sort(array, true);
}
public static double[] sort(double[] array){
return sort(array, true);
}
public static <T extends Comparable<? super T>> T[] sort(T[] array){
return sort(array, true);
}
public static long[] sort(long[] array, boolean replace){
long[] arr = array;
if(!replace){ // if not sorting in place, sort a copy of the original array
arr = new long[array.length];
System.arraycopy(array, 0, arr, 0, array.length);
}
int j;
for(int i = 1; i < arr.length; i++){
long tmp = arr[i];
for(j = i; j > 0 && tmp < arr[j-1]; j--){
arr[j] = arr[j-1];
}
arr[j] = tmp;
}
return arr;
}
public static double[] sort(double[] array, boolean replace){
double[] arr = array;
if (!replace) { // if not sorting in place, sort a copy of the original array
arr = new double[array.length];
System.arraycopy(array, 0, arr, 0, array.length);
}
int j;
for (int i = 1; i < arr.length; i++) {
double tmp = arr[i];
for (j = i; j > 0 && tmp < arr[j - 1]; j--) {
arr[j] = arr[j - 1];
}
arr[j] = tmp;
}
return arr;
}
@SuppressWarnings("unchecked")
public static <T extends Comparable<? super T>> T[] sort(T[] array, boolean replace){
T[] arr = array;
if (!replace) { // if not sorting in place, sort a copy of the original array
arr = array.clone(); // clone() keeps the runtime array type and, unlike arr[0].getClass(), works for empty arrays
}
int j;
for (int i = 1; i < arr.length; i++) {
T tmp = arr[i];
for (j = i; j > 0 && tmp.compareTo(arr[j-1]) < 0; j--) {
arr[j] = arr[j - 1];
}
arr[j] = tmp;
}
return arr;
}
}
ShellSort
package com.onemsg.algorithm.sort;
import java.lang.reflect.Array;
/**
* ShellSort: Shell sort
*/
public class ShellSort {
private ShellSort(){
}
public static long[] sort(long[] array) {
return sort(array, true);
}
public static double[] sort(double[] array) {
return sort(array, true);
}
public static <T extends Comparable<? super T>> T[] sort(T[] array) {
return sort(array, true);
}
public static long[] sort(long[] array, boolean replace){
long[] arr = array;
if(!replace){ // if not sorting in place, sort a copy of the original array
arr = new long[array.length];
System.arraycopy(array, 0, arr, 0, array.length);
}
int j;
int[] seq = getHibbardSequence(array);
for(int gap : seq){
for(int i = gap; i < arr.length; i++){
long tmp = arr[i];
for (j = i; j >= gap && tmp < arr[j - gap]; j -= gap) {
arr[j] = arr[j - gap];
}
arr[j] = tmp;
}
}
return arr;
}
public static double[] sort(double[] array, boolean replace){
double[] arr = array;
if(!replace){ // if not sorting in place, sort a copy of the original array
arr = new double[array.length];
System.arraycopy(array, 0, arr, 0, array.length);
}
int j;
int[] seq = getHibbardSequence(array);
for(int gap : seq){
for(int i = gap; i < arr.length; i++){
double tmp = arr[i];
for (j = i; j >= gap && tmp < arr[j - gap]; j -= gap) {
arr[j] = arr[j - gap];
}
arr[j] = tmp;
}
}
return arr;
}
@SuppressWarnings("unchecked")
public static <T extends Comparable<? super T>> T[] sort(T[] array, boolean replace){
T[] arr = array;
if (!replace) { // if not sorting in place, sort a copy of the original array
arr = array.clone(); // clone() keeps the runtime array type and, unlike arr[0].getClass(), works for empty arrays
}
int j;
int[] seq = getHibbardSequence(array);
for(int gap : seq){
for(int i = gap; i < arr.length; i++){
T tmp = arr[i];
for (j = i; j >= gap && tmp.compareTo(arr[j - gap]) < 0; j -= gap) {
arr[j] = arr[j - gap];
}
arr[j] = tmp;
}
}
return arr;
}
// Build the Hibbard increment sequence h_k = 2^k - 1 (k >= 1), keeping gaps up to N/2, largest first
private static int[] getHibbardSequence(Object array){
int N = Array.getLength(array);
int seqSize = (int) (Math.log(N/2 + 1) / Math.log(2));
int[] seq = new int[seqSize];
int lastIndex = seqSize - 1;
for(int i = 0; i < seqSize; i++){
seq[lastIndex-i] = (int) Math.pow(2, i+1) - 1;
}
return seq;
}
}
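getHibbardSequence derives the number of gaps from floating-point Math.log, which can round just below an exact integer. As an illustration, here is an integer-only sketch intended to produce the same gaps: every 2^k - 1 (k >= 1) that does not exceed N / 2, largest first. The class HibbardGaps and its method hibbardGaps are hypothetical names, not part of the post.

```java
import java.util.Arrays;

// Hypothetical helper, an integer-only alternative to getHibbardSequence.
final class HibbardGaps {
    // Returns all gaps 2^k - 1 (k >= 1) that are <= n / 2, largest first.
    static int[] hibbardGaps(int n) {
        int count = 0;
        for (long g = 1; g <= n / 2; g = 2 * g + 1) {
            count++;                       // count how many gaps fit
        }
        int[] gaps = new int[count];
        long g = 1;
        for (int i = count - 1; i >= 0; i--) {
            gaps[i] = (int) g;             // smallest gap (1) goes last
            g = 2 * g + 1;                 // next Hibbard number: 1, 3, 7, 15, ...
        }
        return gaps;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(hibbardGaps(100))); // [31, 15, 7, 3, 1]
    }
}
```

Because each gap is derived from the previous one (g = 2g + 1), no rounding is involved. Arrays with fewer than two elements get no gaps, which is fine because they are already sorted.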
MergeSort
package com.onemsg.algorithm.sort;
/**
* MergeSort | merge sort
*/
public class MergeSort {
public static long[] sort(long[] array){
long[] tempArray = new long[array.length];
sort(array, tempArray, 0, array.length - 1);
return array;
}
private static void sort(long[] array, long[] tempArray, int left, int right) {
if( left < right){
int center = (left + right) / 2;
sort(array, tempArray, left, center);
sort(array, tempArray, center + 1, right);
merge(array, tempArray, left, center + 1, right);
}
}
private static void merge(long[] array, long[] tempArray, int leftPos, int rightPos, int rightEnd){
int leftEnd = rightPos - 1;
int tmpPos = leftPos;
int numElements = rightEnd - leftPos + 1;
// Main loop
while( leftPos <= leftEnd && rightPos <= rightEnd ){
if( array[leftPos] <= array[rightPos] ){
tempArray[tmpPos++] = array[leftPos++];
}else{
tempArray[tmpPos++] = array[rightPos++];
}
}
while( leftPos <= leftEnd ){
tempArray[ tmpPos++ ] = array[leftPos++];
}
while( rightPos <= rightEnd){
tempArray[ tmpPos++ ] = array[ rightPos++ ];
}
for(int i = 0; i < numElements; i++, rightEnd--){
array[rightEnd] = tempArray[rightEnd];
}
}
}
QuickSort
package com.onemsg.algorithm.sort;
/**
* QuickSort
*/
public class QuickSort {
public final static int CUTOFF = 9;
public static long[] sort(long[] array){
return sort(array, true);
}
public static long[] sort(long[] array, boolean replace){
long[] arr = array;
if(!replace){ // if not sorting in place, sort a copy of the original array
arr = new long[array.length];
System.arraycopy(array, 0, arr, 0, array.length);
}
quicksort(arr, 0, arr.length-1);
return arr;
}
/**
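* Internal recursive quicksort.
* Uses median-of-three to choose the pivot, and sends subarrays with fewer than 10 elements (CUTOFF = 9) to insertion sort.
* @param array the array being sorted
* @param left index of the first element of the subarray
* @param right index of the last element of the subarray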
*/
private static void quicksort(long[] array, int left, int right){
if( left + CUTOFF <= right){
long pivot = median3(array, left, right); // pivot
// partition
int i = left; // left scan pointer
int j = right - 1; // right scan pointer
while (true){
while( array[++i] < pivot ) { }
while( array[--j] > pivot ) { }
if( i < j){
swapLong(array, i, j);
}else{
break;
}
}
swapLong(array, i, right-1); // put the pivot back into its final position i
quicksort(array, left, i-1);
quicksort(array, i+1, right);
}else{
insertionSort(array, left, right);
}
}
private static void insertionSort(long[] array, int left, int right){
int j;
for(int i = left+1; i <= right; i++){
long temp = array[i];
for(j = i; j > left && temp < array[j-1]; j--){
array[j] = array[j-1];
}
array[j] = temp;
}
}
private static long median3(long[] array, int left, int right){
int center = (left + right) / 2;
if(array[center] < array[left]){
swapLong(array, left, center);
}
if(array[right] < array[left]){
swapLong(array, left, right);
}
if(array[right] < array[center]){
swapLong(array, center, right);
}
// move the median of the three (the pivot) to the second-to-last position
swapLong(array, center, right-1);
return array[right-1];
}
private static void swapLong(long[] array, int a, int b){
long temp = array[a];
array[a] = array[b];
array[b] = temp;
}
}
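To make the median-of-three step concrete, here is a standalone copy of median3 and swapLong from QuickSort above (copied so the block compiles on its own; the class name Median3Demo is only for illustration), run on a 10-element array. It puts the smallest of a[left], a[center], a[right] at left and the largest at right. The median, which becomes the pivot, is parked at right - 1.

```java
import java.util.Arrays;

// Demonstration copy of QuickSort.median3 / swapLong; class name is illustrative.
final class Median3Demo {
    // Order a[left], a[center], a[right], then move the median to right - 1 and return it.
    static long median3(long[] a, int left, int right) {
        int center = (left + right) / 2;
        if (a[center] < a[left]) swap(a, left, center);
        if (a[right] < a[left]) swap(a, left, right);
        if (a[right] < a[center]) swap(a, center, right);
        swap(a, center, right - 1);
        return a[right - 1];
    }

    static void swap(long[] a, int i, int j) {
        long t = a[i];
        a[i] = a[j];
        a[j] = t;
    }

    public static void main(String[] args) {
        long[] a = {8, 1, 4, 9, 6, 3, 5, 2, 7, 0};
        long pivot = median3(a, 0, a.length - 1);
        System.out.println(pivot);              // 6
        System.out.println(Arrays.toString(a)); // [0, 1, 4, 9, 7, 3, 5, 2, 6, 8]
    }
}
```

After this call, a[left] <= pivot and a[right] >= pivot. That is why quicksort can start its scan pointers at left and right - 1 and use `++i` / `--j` without bounds checks: the two ends act as sentinels.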
BucketSort
package com.onemsg.algorithm.sort;
/**
* BucketSort | bucket sort (strictly, counting sort: one counter per value in [min, max])
*/
public class BucketSort {
public static int[] sort(int[] array){
// find the minimum and maximum
int max = Integer.MIN_VALUE;
int min = Integer.MAX_VALUE;
for (int l : array) {
if(max < l) max = l;
if(min > l) min = l;
}
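// create the buckets (one counter per value) and fill them;
// int counters, since a byte counter overflows after 127 duplicates
int[] buckets = new int[max - min + 1];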
for (int i = 0; i < array.length; i++) {
buckets[array[i] - min]++;
}
// walk the buckets in order and write each value back as many times as it was counted
int i_array = 0;
for (int i = 0; i < buckets.length; i++) {
while( buckets[i]-- != 0){
array[i_array++] = i + min;
}
}
return array;
}
public static long[] sort(long[] array) {
// find the minimum and maximum
long max = Long.MIN_VALUE;
long min = Long.MAX_VALUE;
for (long l : array) {
if (max < l)
max = l;
if (min > l)
min = l;
}
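// create the buckets (one counter per value) and fill them;
// int counters, since a byte counter overflows after 127 duplicates
int[] buckets = new int[(int) (max - min + 1)];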
for (int i = 0; i < array.length; i++) {
buckets[(int) (array[i] - min)]++;
}
// walk the buckets in order and write each value back as many times as it was counted
int i_array = 0;
for (int i = 0; i < buckets.length; i++) {
while (buckets[i] != 0) {
array[i_array++] = i + min;
buckets[i]--;
}
}
return array;
}
}
RadixSort
package com.onemsg.algorithm.sort;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
/**
* RadixSort | radix sort
*/
public class RadixSort {
public static long[] sort(long[] array) {
return sort(array, true);
}
@Deprecated
public static long[] sort(long[] array, boolean replace) {
long[] arr = array;
if (!replace) { // if not sorting in place, sort a copy of the original array
arr = new long[array.length];
System.arraycopy(array, 0, arr, 0, array.length);
}
// initialize 10 buckets, one per decimal digit (assumes non-negative values)
List<List<Long>> buckets = new ArrayList<>(10);
for (int i = 0; i < 10; i++) {
buckets.add(new ArrayList<>());
}
long max = Long.MIN_VALUE;
for (int i = 0; i < arr.length; i++) {
if (max < arr[i]) {
max = arr[i];
}
}
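int p_max = (int) Math.log10(max) + 1; // number of decimal digits in max
int p = 1; // current pass
long radix = 1; // long, so 10^k does not overflow for values with 10 or more digits
while (p <= p_max) {
for (int i = 0; i < arr.length; i++) {
// divide before casting; (int) arr[i] / radix would truncate values above Integer.MAX_VALUE first
buckets.get((int) (arr[i] / radix % 10)).add(arr[i]);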
}
int i = 0;
for (List<Long> bucket : buckets) {
for (Long n : bucket) {
arr[i++] = n;
}
bucket.clear();
}
p++;
radix *= 10;
}
return arr;
}
/** Parallel bucket sort */
public static long[] parallelBucketsort(long[] array) {
final int N = array.length < 10000 ? 1 : 10; // number of buckets (one thread per bucket)
List<List<Long>> buckets = new ArrayList<>(N);
IntStream.range(0, N).forEach(n -> buckets.add(new ArrayList<>()));
long min = Long.MAX_VALUE;
long max = Long.MIN_VALUE;
for (int i = 0; i < array.length; i++) {
if (min > array[i]) {
min = array[i];
}
if (max < array[i]) {
max = array[i];
}
}
long step = Math.max(1, (max - min) / N); // at least 1, to avoid dividing by zero when max - min < N
int lastIndex = N - 1;
for (long n : array) {
buckets.get(getBucketIndex(n, min, lastIndex, step)).add(n);
}
List<long[]> bucketsArray = new ArrayList<>();
for (int i = 0, end = buckets.size(); i < end; i++) {
bucketsArray.add(buckets.get(i).stream().mapToLong(n -> (long) n).toArray());
buckets.set(i, null);
}
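// start one worker thread per bucket, then wait for all of them with join()
// (instead of busy-waiting on ThreadGroup.activeCount(), which keeps a CPU core spinning)
List<Thread> workers = new ArrayList<>(N);
for (int i = 0; i < N; i++) {
long[] bucket = bucketsArray.get(i);
Thread worker = new Thread(() -> ShellSort.sort(bucket), "worker-bucket-" + i);
worker.start();
workers.add(worker);
}
for (Thread worker : workers) {
try {
worker.join(); // join() also guarantees this thread sees the sorted bucket
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new IllegalStateException("interrupted while sorting", e);
}
}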
return bucketsArray.stream().flatMapToLong(arr -> LongStream.of(arr)).toArray();
}
private static int getBucketIndex(long n, long min, int lastIndex, long step){
int index = (int) ( (n - min) / step );
return index <= lastIndex ? index : lastIndex;
}
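/** Radix sort for strings: all strings must have the same length and be ASCII (more precisely, every char must be below 256, one bucket per value) */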
public static String[] sortString(String[] array, int stringLen){
final int BUCKETS = 256;
List<List<String>> buckets = new ArrayList<>(BUCKETS);
IntStream.range(0, BUCKETS).forEach( i -> buckets.add(new ArrayList<>()) );
// List<String>[] ls = new ArrayList[]<String>();
for(int pos = stringLen - 1; pos >= 0; pos--){
for( String s : array){
buckets.get(s.charAt(pos)).add(s);
}
int index = 0;
for(List<String> bucket : buckets){
for (String s : bucket) {
array[index++] = s;
}
bucket.clear();
}
}
return array;
}
}
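RadixSort.sort above boxes every value into a List<Long> bucket and assumes non-negative numbers. As a hedged alternative (not the book's version), here is a sketch of a least-significant-digit radix sort that works one byte per pass (base 256). It uses int count arrays and a single scratch long[], so nothing is boxed. Flipping the sign bit (v ^ Long.MIN_VALUE) makes unsigned byte order match signed order, so negative values sort correctly too. The class name ByteRadixSort is hypothetical.

```java
import java.util.Arrays;

// Hypothetical byte-wise LSD radix sort; not the RadixSort from the post.
final class ByteRadixSort {
    // Byte (shift / 8) of v, with the sign bit flipped so unsigned order matches signed order.
    private static int digit(long v, int shift) {
        return (int) (((v ^ Long.MIN_VALUE) >>> shift) & 0xFF);
    }

    // Sorts array in place and returns it: 8 passes, O(n + 256) work each, O(n) extra space.
    static long[] sort(long[] array) {
        long[] src = array;
        long[] dst = new long[array.length];
        for (int shift = 0; shift < 64; shift += 8) {
            int[] start = new int[257];
            for (long v : src) {
                start[digit(v, shift) + 1]++;      // histogram of this byte
            }
            for (int d = 0; d < 256; d++) {
                start[d + 1] += start[d];          // prefix sums: first output slot of each digit
            }
            for (long v : src) {
                dst[start[digit(v, shift)]++] = v; // stable scatter preserves earlier passes' order
            }
            long[] tmp = src;
            src = dst;
            dst = tmp;
        }
        return src; // 8 passes = even number of swaps, so src is the original array again
    }

    public static void main(String[] args) {
        long[] a = {170, -45, 75, 90, -802, 24, 2, 66};
        System.out.println(Arrays.toString(sort(a))); // [-802, -45, 2, 24, 66, 75, 90, 170]
    }
}
```

Each pass must be stable. Otherwise the order established by the lower bytes would be lost, and that order is what makes LSD radix sort correct.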
HeapSort
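The HeapSort class that TestSort.runHeapSort calls (HeapSort.sort(Integer[])) does not appear in this post. As a stand-in, here is a minimal sketch assuming the same shape as the other classes: a generic in-place sort that returns its argument. It follows the usual array-based max-heap approach: build the heap bottom-up, then repeatedly swap the maximum to the end and sift the new root down. It is not the author's missing code.

```java
import java.util.Arrays;

// Hypothetical stand-in for the missing HeapSort; not the author's code.
final class HeapSort {
    private HeapSort() {
    }

    // Sorts array in place (ascending) and returns it, as TestSort expects.
    static <T extends Comparable<? super T>> T[] sort(T[] array) {
        int n = array.length;
        // Phase 1: build a max-heap bottom-up.
        for (int i = n / 2 - 1; i >= 0; i--) {
            percDown(array, i, n);
        }
        // Phase 2: move the current max to the end, shrink the heap, restore the heap property.
        for (int end = n - 1; end > 0; end--) {
            swap(array, 0, end);
            percDown(array, 0, end);
        }
        return array;
    }

    // Sift a[i] down within a[0, size) until neither child is larger.
    private static <T extends Comparable<? super T>> void percDown(T[] a, int i, int size) {
        T tmp = a[i];
        int child;
        for (; 2 * i + 1 < size; i = child) {
            child = 2 * i + 1;
            if (child + 1 < size && a[child].compareTo(a[child + 1]) < 0) {
                child++;                  // pick the larger child
            }
            if (tmp.compareTo(a[child]) < 0) {
                a[i] = a[child];          // move the larger child up one level
            } else {
                break;
            }
        }
        a[i] = tmp;
    }

    private static <T> void swap(T[] a, int i, int j) {
        T t = a[i];
        a[i] = a[j];
        a[j] = t;
    }

    public static void main(String[] args) {
        Integer[] a = {31, 41, 59, 26, 53, 58, 97};
        System.out.println(Arrays.toString(sort(a))); // [26, 31, 41, 53, 58, 59, 97]
    }
}
```

Because it works on Comparable objects, every comparison is a compareTo call on a boxed value. That matches the post's explanation of why heap sort was slow in the test.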
ParallelSort
package com.onemsg.algorithm.sort;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
/**
* ParallelSort
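* A home-made multithreaded sort. Splitting into buckets and converting between List<Long> and long[] take too long, so it still needs optimizing.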
*/
public class ParallelSort {
private ParallelSort(){
}
public static long[] sort(long[] array){
return parallelBucketsort(array);
}
/** Parallel bucket sort */
public static long[] parallelBucketsort(long[] array) {
final int N = array.length < 10000 ? 1 : 10; // number of buckets (one thread per bucket)
List<List<Long>> buckets = createBuckets(array, N);
List<long[]> bucketsArray = new ArrayList<>();
for (int i = 0, end = buckets.size(); i < end; i++) {
bucketsArray.add(buckets.get(i).stream().mapToLong(n -> (long) n).toArray());
buckets.set(i, null);
}
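// start one worker thread per bucket, then wait for all of them with join()
// (instead of busy-waiting on ThreadGroup.activeCount(), which keeps a CPU core spinning)
List<Thread> workers = new ArrayList<>(N);
for (int i = 0; i < N; i++) {
long[] bucket = bucketsArray.get(i);
Thread worker = new Thread(() -> QuickSort.sort(bucket), "worker-bucket-" + i);
worker.start();
workers.add(worker);
}
for (Thread worker : workers) {
try {
worker.join(); // join() also guarantees this thread sees the sorted bucket
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new IllegalStateException("interrupted while sorting", e);
}
}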
return bucketsArray.stream().flatMapToLong(arr -> LongStream.of(arr)).toArray();
}
private static List<List<Long>> createBuckets(long[] array, int N){
List<List<Long>> buckets = new ArrayList<>(N);
IntStream.range(0, N).forEach(n -> buckets.add(new ArrayList<>()));
long min = Long.MAX_VALUE;
long max = Long.MIN_VALUE;
for (int i = 0; i < array.length; i++) {
if (min > array[i]) {
min = array[i];
}
if (max < array[i]) {
max = array[i];
}
}
long step = Math.max(1, (max - min) / N); // at least 1, to avoid dividing by zero when max - min < N
int lastIndex = N - 1;
for (long n : array) {
buckets.get(getBucketIndex(n, min, lastIndex, step)).add(n);
}
return buckets;
}
private static int getBucketIndex(long n, long min, int lastIndex, long step) {
int index = (int) ((n - min) / step);
return index <= lastIndex ? index : lastIndex;
}
public static void main(String[] args) {
}
}
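ParallelSort spends much of its time boxing values into List<Long> buckets and converting them back, as its own Javadoc notes. One possible direction is sketched here, assuming a fork/join merge sort is acceptable in place of the bucket approach. It keeps everything in primitive long[] arrays and lets ForkJoinPool schedule the work: each task sorts its two halves in parallel and then merges them through a shared scratch buffer. The base case uses Arrays.sort so the block compiles on its own (QuickSort.sort could be swapped in). The class name ForkJoinMergeSort and the THRESHOLD value are hypothetical and untuned.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// Hypothetical alternative to ParallelSort: fork/join merge sort on primitive arrays.
final class ForkJoinMergeSort {
    // Below this size a task sorts sequentially (untuned guess).
    private static final int THRESHOLD = 1 << 13;

    static long[] sort(long[] array) {
        long[] tmp = new long[array.length];
        ForkJoinPool.commonPool().invoke(new SortTask(array, tmp, 0, array.length));
        return array;
    }

    // Sorts a[lo, hi) in place, using tmp[lo, hi) as scratch space for merging.
    private static final class SortTask extends RecursiveAction {
        private final long[] a, tmp;
        private final int lo, hi;

        SortTask(long[] a, long[] tmp, int lo, int hi) {
            this.a = a;
            this.tmp = tmp;
            this.lo = lo;
            this.hi = hi;
        }

        @Override
        protected void compute() {
            if (hi - lo <= THRESHOLD) {
                Arrays.sort(a, lo, hi);      // sequential base case
                return;
            }
            int mid = (lo + hi) >>> 1;
            // Both halves run in parallel; invokeAll returns once both are done.
            invokeAll(new SortTask(a, tmp, lo, mid), new SortTask(a, tmp, mid, hi));
            merge(mid);
        }

        // Merge the sorted runs a[lo, mid) and a[mid, hi) back into a[lo, hi).
        private void merge(int mid) {
            System.arraycopy(a, lo, tmp, lo, hi - lo);
            int i = lo, j = mid, k = lo;
            while (i < mid && j < hi) {
                a[k++] = tmp[i] <= tmp[j] ? tmp[i++] : tmp[j++];
            }
            while (i < mid) a[k++] = tmp[i++];
            while (j < hi) a[k++] = tmp[j++];
        }
    }
}
```

Tasks touch disjoint ranges of both arrays, so no locking is needed. Whether this beats the single-threaded QuickSort on a particular machine has to be measured; no timings are claimed here.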
Utility class | Test class
ArrayUtil
package com.onemsg.util;
/**
* ArrayUtil | convenience methods for printing an array or a slice of it
*/
public class ArrayUtil {
public static void main(final String[] args) {
final long[] a = new long[] { 1, 2, 3, 4, 5, 6, 7, 8 };
System.out.println(toString(a, 0, a.length));
System.out.println(toString(a, 0, -2));
System.out.println(toString(a, 3, 3));
System.out.println(toString(a, -5, a.length));
}
public static String toString(final long[] a) {
return toString(a, 0, a.length);
}
public static String toString(final long[] a, final int end) {
return toString(a, 0, end);
}
public static String toString(final int[] a) {
return toString(a, 0, a.length);
}
public static String toString(final int[] a, final int end) {
return toString(a, 0, end);
}
public static String toString(final float[] a) {
return toString(a, 0, a.length);
}
public static String toString(final float[] a, final int end) {
return toString(a, 0, end);
}
public static String toString(final double[] a) {
return toString(a, 0, a.length);
}
public static String toString(final double[] a, final int end) {
return toString(a, 0, end);
}
public static String toString(final Object[] a) {
return toString(a, 0, a.length);
}
public static String toString(final Object[] a, final int end) {
return toString(a, 0, end);
}
public static String toString(final long[] a, int start, final int end) {
if (a == null)
return "null";
start = start >= 0 ? start : a.length + start;
final int iMax = end > 0 ? end - 1 : a.length + end - 1;
if (start > iMax) { // empty range; start == iMax is a valid one-element slice
return "[]";
}
final StringBuilder b = new StringBuilder();
b.append('[');
for (int i = start; ; i++) {
b.append(a[i]);
if (i == iMax)
return b.append(']').toString();
b.append(", ");
}
}
public static String toString(final int[] a, int start, final int end) {
if (a == null)
return "null";
start = start >= 0 ? start : a.length + start;
final int iMax = end > 0 ? end - 1 : a.length + end - 1;
if (start > iMax) { // empty range; start == iMax is a valid one-element slice
return "[]";
}
final StringBuilder b = new StringBuilder();
b.append('[');
for (int i = start;; i++) {
b.append(a[i]);
if (i == iMax)
return b.append(']').toString();
b.append(", ");
}
}
public static String toString(final float[] a, int start, final int end) {
if (a == null)
return "null";
start = start >= 0 ? start : a.length + start;
final int iMax = end > 0 ? end - 1 : a.length + end - 1;
if (start > iMax) { // empty range; start == iMax is a valid one-element slice
return "[]";
}
final StringBuilder b = new StringBuilder();
b.append('[');
for (int i = start;; i++) {
b.append(a[i]);
if (i == iMax)
return b.append(']').toString();
b.append(", ");
}
}
public static String toString(final double[] a, int start, final int end) {
if (a == null)
return "null";
start = start >= 0 ? start : a.length + start;
final int iMax = end > 0 ? end - 1 : a.length + end - 1;
if (start > iMax) { // empty range; start == iMax is a valid one-element slice
return "[]";
}
final StringBuilder b = new StringBuilder();
b.append('[');
for (int i = start;; i++) {
b.append(a[i]);
if (i == iMax)
return b.append(']').toString();
b.append(", ");
}
}
public static String toString(final Object[] a, int start, final int end) {
if (a == null)
return "null";
start = start >= 0 ? start : a.length + start;
final int iMax = end > 0 ? end - 1 : a.length + end - 1;
if (start > iMax) { // empty range; start == iMax is a valid one-element slice
return "[]";
}
final StringBuilder b = new StringBuilder();
b.append('[');
for (int i = start;; i++) {
b.append(a[i]);
if (i == iMax)
return b.append(']').toString();
b.append(", ");
}
}
}
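For comparison, the JDK can print a slice without a custom helper by combining Arrays.copyOfRange and Arrays.toString. The sketch below mirrors ArrayUtil's convention that a negative start counts from the end of the array; it does not mirror the convention that a non-positive end counts from the end. The class SliceToString and its method slice are hypothetical names.

```java
import java.util.Arrays;

// Hypothetical helper; slice(a, from, to) prints a[from, to).
final class SliceToString {
    // A negative `from` counts from the end (like ArrayUtil's `start`); `to` is a normal exclusive index.
    static String slice(long[] a, int from, int to) {
        int start = from >= 0 ? from : a.length + from;
        return Arrays.toString(Arrays.copyOfRange(a, start, to));
    }

    public static void main(String[] args) {
        long[] a = {1, 2, 3, 4, 5, 6, 7, 8};
        System.out.println(slice(a, 0, a.length));  // [1, 2, 3, 4, 5, 6, 7, 8]
        System.out.println(slice(a, -5, a.length)); // [4, 5, 6, 7, 8]
        System.out.println(slice(a, 3, 3));         // []
    }
}
```

The trade-off is an extra array copy per call, which is negligible when the goal is printing.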
TestSort
package com.onemsg.algorithm.sort;
import java.util.Arrays;
import java.util.Random;
/**
* TestSort
*/
public class TestSort {
public static void main(String[] args) {
long[] array = new Random(2020).longs(10_000_000, 1, 10_000_000).toArray();
// runShellSort(array);
testSpendTime(array, QuickSort::sort, "Quicksort");
array = null;
array = new Random(2020).longs(10_000_000, 1, 10_000_000).toArray();
testSpendTime(array, ParallelSort::sort, "Multithreaded sort");
array = null;
runHeapSort();
}
public static void runHeapSort(){
Integer[] array = new Random(2020).ints(10_000_000, 1, 10_000_000).boxed().toArray(Integer[]::new);
System.out.printf("Algorithm: %s, array size: %d, first ten elements:%n", "Heap sort", array.length);
System.out.println(Arrays.toString(Arrays.copyOf(array, 10)));
long startTime = System.currentTimeMillis();
array = HeapSort.sort(array);
long endTime = System.currentTimeMillis();
System.out.println("After sorting, first ten elements:");
System.out.println(Arrays.toString(Arrays.copyOf(array, 10)));
System.out.printf("spend time: %d ms\n", endTime - startTime);
System.out.println("- ".repeat(30));
}
public static void runJavaSort(){
long[] array = new Random(2020).longs(100_000_000, 1, 10_000_000).toArray();
long startTime = System.currentTimeMillis();
Arrays.parallelSort(array);
long endTime = System.currentTimeMillis();
System.out.printf("spend time: %d ms\n", endTime - startTime);
}
public static void runParallelSort(long[] array) {
testSpendTime(array, ParallelSort::sort, "Home-made multithreaded sort");
}
public static void testSpendTime(long[] array, Sort sort, String sortName){
System.out.printf("Algorithm: %s, array size: %d, first ten elements:%n", sortName, array.length);
System.out.println(Arrays.toString(Arrays.copyOf(array, 10)));
long startTime = System.currentTimeMillis();
array = sort.sort(array);
long endTime = System.currentTimeMillis();
System.out.println("After sorting, first ten elements:");
System.out.println(Arrays.toString(Arrays.copyOf(array, 10)));
System.out.printf("spend time: %d ms\n", endTime - startTime);
System.out.println("- ".repeat(30));
}
}
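testSpendTime only prints the first ten elements, which cannot reveal an unsorted tail. The small extension sketched here is an assumption, not part of the original tests. It sorts a copy, checks the result against Arrays.sort, and times the call with System.nanoTime, which is meant for measuring elapsed time. The Sort interface is redeclared so the block compiles on its own, and the class name SortCheck is hypothetical.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical test helper that verifies a sort before reporting its time.
final class SortCheck {
    // Same shape as the post's Sort functional interface.
    @FunctionalInterface
    interface Sort {
        long[] sort(long[] array);
    }

    // Sorts a copy of input with `sort`, verifies it against Arrays.sort, returns elapsed ms.
    static long checkAndTime(long[] input, Sort sort, String name) {
        long[] expected = input.clone();
        Arrays.sort(expected);
        long[] work = input.clone();
        long start = System.nanoTime();
        long[] result = sort.sort(work);
        long elapsedMs = (System.nanoTime() - start) / 1_000_000;
        if (!Arrays.equals(result, expected)) {
            throw new IllegalStateException(name + " produced an incorrectly sorted array");
        }
        System.out.printf("%s: %d elements, %d ms%n", name, input.length, elapsedMs);
        return elapsedMs;
    }

    public static void main(String[] args) {
        long[] data = new Random(2020).longs(1_000_000, 1, 10_000_000).toArray();
        checkAndTime(data, a -> { Arrays.sort(a); return a; }, "Arrays.sort");
    }
}
```

In the post's project, QuickSort::sort or ParallelSort::sort could be passed the same way, since both match the Sort shape.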
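Test results
The input is a long[] of ten million elements, and quicksort is clearly the fastest. Heap sort took so long because that implementation is generic: it sorts boxed objects such as Long and Integer (an Integer[] in runHeapSort), and working with objects costs a lot of time.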
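Summary
1. I wrote lots of method overloads for the first sorting algorithms (which is why they take up so much space 😂). For the later ones I got lazy and skipped them, since understanding the algorithm itself is what matters most 😘.
2. Bucket sort (here effectively counting sort) is the fastest in some cases. It runs in linear time, O(n + k) where k = max - min + 1 is the range of values (not constant time), so it only pays off when that range is small. I tested sorting a long[] of one hundred million elements and it took only two seconds.
3. For everyday sorting in Java, just use Arrays.sort() or Arrays.parallelSort(); both come with many overloads. For primitive arrays, Arrays.sort() uses a dual-pivot quicksort (object arrays use TimSort, a merge sort variant). Arrays.parallelSort() is a parallel sort-merge that runs on the ForkJoin common pool rather than on parallel streams.
4. There is of course much more to summarize about each sorting algorithm; I'll add it when I have time 😁.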
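Here is a minimal usage sketch of the two JDK methods from point 3 (the class name JdkSortDemo is only for illustration). Both sort in place and return nothing. Arrays.sort also has overloads for a subrange and, for object arrays, for a Comparator.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;

// Illustrative class name; shows in-place JDK sorting calls.
final class JdkSortDemo {
    // Returns true if a is in non-decreasing order.
    static boolean isSorted(long[] a) {
        for (int i = 1; i < a.length; i++) {
            if (a[i - 1] > a[i]) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        long[] a = {5, 3, 9, 1, 7};
        Arrays.sort(a);                            // primitive array: dual-pivot quicksort, in place
        System.out.println(Arrays.toString(a));    // [1, 3, 5, 7, 9]

        long[] b = {5, 3, 9, 1, 7};
        Arrays.sort(b, 1, 4);                      // only indices 1..3; toIndex is exclusive
        System.out.println(Arrays.toString(b));    // [5, 1, 3, 9, 7]

        Integer[] c = {5, 3, 9, 1, 7};
        Arrays.sort(c, Comparator.reverseOrder()); // object array: stable merge sort (TimSort)
        System.out.println(Arrays.toString(c));    // [9, 7, 5, 3, 1]

        long[] d = new Random(2020).longs(1_000_000).toArray();
        Arrays.parallelSort(d);                    // parallel sort-merge on the ForkJoin common pool
        System.out.println(isSorted(d));           // true
    }
}
```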
The test code and the parallel classes use Streams. If you're interested, let's discuss in the comments~ 😄😜