题目
求一个数组中最长递增子序列,要求时间复杂度尽量低。
- 测试输入
7
2 1 4 3 1 5 6 - 测试输出
4
分析1
采取从后向前的分析思路。如果已知第
i
个元素存在于最长递增子序列中,那么前
代码1
import java.util.Scanner;
public class LIS {
static int solution(int[] array, int[] lis) {
int n = array.length;
lis[0] = 1;
for (int i = 1; i < n; i++) {
int maxLen = 0;// array[i]之前的最长子序列长度
for (int k = 0; k < i; k++) {
if (array[k] < array[i] && lis[k] > maxLen) {
maxLen = lis[k];
}
}
// 若没有找到比array[i]小的值,array[i]构成长度为1的子序列
// 若找到则在最长递增序列末尾加上array[i]构成新最长子序列
lis[i] = maxLen + 1;
}
return max(lis);
}
static int max(int[] array) {
int maxValue = Integer.MIN_VALUE;
for (int i : array) {
if (i > maxValue) maxValue = i;
}
return maxValue;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int[] array = new int[n];
for (int i = 0; i < n; i++) {
array[i] = sc.nextInt();
}
int[] lis = new int[n];
System.out.println(solution(array, lis));
}
}
分析2
上述做法时间复杂度为
O(n2)
,原因在于max查找需要遍历一遍数组。如果能将查找过程优化就能得到更优的解法。应该注意一个事实,在找前
i
个元素的最长递增子序列时,应该尽量保证已经找到的子序列的末尾元素尽量小,这样第[2 1 4 3 1 5 6]
每一步分析如下:
i | tail | |
---|---|---|
0 | 2 | 2(初始值) |
1 | 1 | 1(替换2) |
2 | 4 | 1 4 |
3 | 3 | 1 3(替换4) |
4 | 1 | 1 3 |
5 | 5 | 1 3 5 |
6 | 6 | 1 3 5 6 |
可以看到数组 tail 就保存着最长递增子序列的所有值,且其中的每个元素总是在可选元素的最后一个。因为数组 tail 一定递增,如果将查找过程改为使用二分查找,则整个算法的时间复杂度为 O(nlog(n)) 。
代码2
import java.util.Scanner;
public class LIS {
static int solution2(int[] array, int[] tail) {
int n = array.length;
tail[0] = array[0];
int maxLen = 1;
for (int i = 1; i < n; i++) {
if (array[i] > tail[maxLen - 1]) {
tail[maxLen] = array[i];
++maxLen;
} else {// 线性查找
for (int k = 0; k < maxLen; k++) {
if (tail[k] >= array[i]) {
tail[k] = array[i];
break;
}
}
}
}
return maxLen;
}
static int solution3(int[] array, int[] tail) {
int n = array.length;
tail[0] = array[0];
int maxLen = 1;
for (int i = 1; i < n; i++) {
if (array[i] > tail[maxLen - 1]) {
tail[maxLen] = array[i];
++maxLen;
} else {// 二分查找
int pos = binarySearch(tail, maxLen, array[i]);
tail[pos] = array[i];
}
}
return maxLen;
}
static int binarySearch(int[] array, int limit, int value) {
int start = 0;
int end = limit - 1;
while (start <= end) {
int mid = start + (end - start) / 2;
if (array[mid] == value) {
return mid;
} else if (array[mid] < value) {
start = mid + 1;
} else {
end = mid - 1;
}
}
return start;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int[] array = new int[n];
for (int i = 0; i < n; i++) {
array[i] = sc.nextInt();
}
int[] tail = new int[n];
System.out.println(solution2(array, tail));
}
}