package com;

public class MergeSort {
    // Total number of inversions found while sorting.
    // It is static and accumulates across calls, so reset it to 0 before sorting a new array.
    public static int inversionCounts = 0;
    private static Integer[] unsortedArray;

    public static Integer[] getUnsortedArray() {
        return unsortedArray;
    }

    public static void setUnsortedArray(Integer[] unsortedArray) {
        MergeSort.unsortedArray = unsortedArray;
    }

    // Sorts unsortedArray[from..end] (both ends inclusive) in ascending order and returns it as a new array.
    // While merging, counts the inversions: pairs (p, q) with p < q and a[p] > a[q].
    public static Integer[] mergeSort(int from, int end)
    {
        // Empty input: nothing to sort
        if (unsortedArray.length == 0)
            return new Integer[0];
        // Base case: a single element is already sorted
        if (end == from)
            return new Integer[] { unsortedArray[end] };
        // Split the range into two halves and sort each half recursively
        int middle = (from + end) / 2;
        Integer[] halfArray = mergeSort(from, middle);
        Integer[] anotherHalfArray = mergeSort(middle + 1, end);
        // Both halves are now sorted; merge them into a new temporary array
        int size = halfArray.length + anotherHalfArray.length;
        Integer[] newArray = new Integer[size];
        int i = 0, j = 0, k = 0;
        // Number of left-half elements not yet copied into newArray
        int leftOverInLeftHalfArray = halfArray.length;
        while (j != halfArray.length && k != anotherHalfArray.length)
        {
            if (halfArray[j] <= anotherHalfArray[k])
            {
                // The left element is not larger: take it, no inversion
                newArray[i] = halfArray[j];
                j++;
                leftOverInLeftHalfArray--;
            }
            else
            {
                // The right element is smaller than every remaining left element,
                // so it forms a split inversion with each of them
                newArray[i] = anotherHalfArray[k];
                k++;
                inversionCounts += leftOverInLeftHalfArray;
            }
            i++;
        }
        // Clean up: copy whatever remains in either half
        while (j != halfArray.length)
        {
            newArray[i++] = halfArray[j++];
        }
        while (k != anotherHalfArray.length)
        {
            newArray[i++] = anotherHalfArray[k++];
        }
        return newArray;
    }
}
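Big-O analysis:
Let n be the length of the array. Each call splits its range in half and makes two recursive calls, mergeSort(from, middle) and mergeSort(middle + 1, end), so level k of the recursion tree (counting the top-level call as level 0) contains 2^k calls.
At the same time, the subproblem handled by each call (the length of the arrays that the compare, merge and clean-up steps operate on) shrinks to n/2^k, so the total work on level k is proportional to n/2^k * 2^k = n.
The recursion tree has height log2 n, so the total running time of the recursion is n * log2 n, i.e. O(n log n).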
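The same level-by-level count can be written as the usual merge sort recurrence. This is a sketch assuming n is a power of 2 and that merging two halves of total length m costs at most c·m, where the constant c is introduced here only for illustration:

$$T(1) = c, \qquad T(n) = 2\,T\!\left(\frac{n}{2}\right) + c\,n$$

$$T(n) = \sum_{k=0}^{\log_2 n} 2^k \cdot c\,\frac{n}{2^k} = c\,n\,(\log_2 n + 1) = O(n \log n)$$

Each of the log2 n + 1 levels (the last one being the n single-element leaves) contributes c·n, which is exactly the "n per level times log2 n levels" argument above.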
Test case:
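package test;

import java.util.Random;

import com.MergeSort;

public class TestMergeSort {
    private static Integer[] numbers;
    private final static int SIZE = 20;
    private final static int MAX = 20;

    static
    {
        // Fill the array with SIZE random integers in the range [0, MAX)
        numbers = new Integer[SIZE];
        Random generator = new Random();
        for (int i = 0; i < SIZE; i++)
        {
            numbers[i] = generator.nextInt(MAX);
        }
    }

    public static void printArray(Integer[] array)
    {
        for (int i = 0; i < array.length; i++)
            System.out.print(array[i] + " ");
        System.out.println();
    }

    // Brute-force O(n^2) count of pairs (p, q) with p < q and array[p] > array[q],
    // used to cross-check the count produced during the merge sort
    private static int countInversionsBruteForce(Integer[] array)
    {
        int count = 0;
        for (int p = 0; p < array.length; p++)
            for (int q = p + 1; q < array.length; q++)
                if (array[p] > array[q])
                    count++;
        return count;
    }

    public static void main(String[] args) {
        printArray(numbers);
        int expected = countInversionsBruteForce(numbers);
        MergeSort.inversionCounts = 0; // the counter is static, so reset it before each run
        long startTime = System.nanoTime();
        MergeSort.setUnsortedArray(numbers);
        numbers = MergeSort.mergeSort(0, SIZE - 1);
        long stopTime = System.nanoTime();
        printArray(numbers);
        System.out.println("Elapsed time (ns): " + (stopTime - startTime));
        System.out.println("Inversion Counts: " + MergeSort.inversionCounts);
        System.out.println("Inversion Counts (brute force): " + expected);
    }
}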
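Alongside the sort, the code also counts the inversions in the whole array, using the idea of left-side inversions, right-side inversions and split inversions: left-side and right-side inversions lie entirely within one half and are counted by the two recursive calls, while split inversions (one element in each half) are counted in the merge step, by adding the number of elements still remaining in the left half each time an element is taken from the right half. For a detailed explanation, see: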
https://cgi.csc.liv.ac.uk/~martin/teaching/comp202/Java/Inversions-code.html
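The left/right/split decomposition can be made explicit by having each call return its own count instead of accumulating into the static inversionCounts field. The following is a minimal sketch under that assumption; the class SortAndCount and the methods sortAndCount and countInversions are hypothetical names chosen for this example, and it works on int[] rather than Integer[]:

```java
import java.util.Arrays;

// Hypothetical sketch: merge sort that returns the inversion count
// instead of storing it in a static field.
public class SortAndCount {

    // Sorts a[from..to) in place (ascending) and returns the number of
    // inversions, i.e. pairs p < q with a[p] > a[q], inside that range.
    static long sortAndCount(int[] a, int[] tmp, int from, int to) {
        if (to - from <= 1) {
            return 0; // zero or one element: no inversions
        }
        int mid = (from + to) >>> 1; // unsigned shift avoids overflow of from + to
        long left = sortAndCount(a, tmp, from, mid); // inversions inside the left half
        long right = sortAndCount(a, tmp, mid, to);  // inversions inside the right half
        long split = 0;                              // pairs with one element in each half
        int i = from, j = mid, k = from;
        while (i < mid && j < to) {
            if (a[i] <= a[j]) {
                tmp[k++] = a[i++];
            } else {
                // a[j] is smaller than every element of a[i..mid),
                // so it forms (mid - i) split inversions
                split += mid - i;
                tmp[k++] = a[j++];
            }
        }
        while (i < mid) tmp[k++] = a[i++];
        while (j < to) tmp[k++] = a[j++];
        System.arraycopy(tmp, from, a, from, to - from);
        return left + right + split;
    }

    // Counts inversions without modifying the caller's array.
    public static long countInversions(int[] input) {
        int[] a = Arrays.copyOf(input, input.length);
        return sortAndCount(a, new int[a.length], 0, a.length);
    }

    public static void main(String[] args) {
        int[] example = {2, 4, 1, 3, 5};
        System.out.println(countInversions(example)); // prints 3: (2,1), (4,1), (4,3)
    }
}
```

Returning the count keeps each call free of shared state, so sorting a second array does not require resetting a counter. The long return type is a deliberate choice: an array of n elements can have up to n(n-1)/2 inversions, which overflows int for large n.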