题目链接:点击打开链接
思路:类似于最长上升子序列, 我们很容易得到一个n^2的算法, 但是时间复杂度无法承受。
可以发现, |a[i]-a[j]| >= d相当于对于每个j, 找到一个i < j && (a[i] <= a[j]-d || a[i] >= a[j]+d)中最大的dp[i]。 我们将数字大小离散化之后做线段树下标, 然后维护一个区间最大值就行了。 时间复杂度O(nlogn)
细节参见代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 2e5 + 10;
ll T,n,d,pos[maxn], a[maxn], b[maxn],p[maxn] ;
int dp[maxn];
struct node {
int id, v;
node(int id=0, int v=0):id(id), v(v) {}
}maxv[maxn<<2];
void pushup(int o) {
maxv[o] = maxv[o<<1];
if(maxv[o].v < maxv[o<<1|1].v) {
maxv[o] = maxv[o<<1|1];
}
}
void build(int l, int r, int o) {
int m = (l + r) >> 1;
maxv[o] = 0;
if(l == r) return ;
build(l, m, o<<1);
build(m+1, r, o<<1|1);
}
void update(int L, int R, int id, int v, int l, int r, int o) {
int m = (l + r) >> 1;
if(L <= l && r <= R) {
maxv[o] = node(id, v);
return ;
}
if(L <= m) update(L, R,id, v, l, m, o<<1);
if(m < R) update(L, R,id, v, m+1, r, o<<1|1);
pushup(o);
}
node query(int L, int R, int l, int r, int o) {
int m = (l + r) >> 1;
if(L <= l && r <= R) {
return maxv[o];
}
node ans = node(0, 0);
if(L <= m) {
node cur = query(L, R, l, m, o<<1);
if(ans.v < cur.v) {
ans = cur;
}
}
if(m < R) {
node cur = query(L, R, m+1, r, o<<1|1);
if(ans.v < cur.v) {
ans = cur;
}
}
return ans;
}
void path(int root) {
vector<int> q;
while(root != -1) {
q.push_back(root);
root = p[root];
}
for(int i = q.size()-1; i >= 0; i--) printf("%d%c", q[i], i == 0 ? '\n' : ' ');
}
int main() {
scanf("%I64d%I64d",&n,&d);
for(int i = 1; i <= n; i++) {
scanf("%I64d", &a[i]);
b[i-1] = a[i];
}
sort(b, b+n);
int len = unique(b, b+n) - b;
build(1, len, 1);
int res = 0;
memset(p, -1, sizeof(p));
for(int i = 1; i <= n; i++) {
int pos = lower_bound(b, b+len, a[i]) - b + 1;
int& ans = dp[i] = 1;
int p1 = lower_bound(b, b+len, a[i]-d) - b;
if(b[p1] > a[i]-d) ;
else p1++;
int p2 = lower_bound(b, b+len, a[i]+d) - b + 1;
if(p1 >= 1) {
node cur = query(1, p1, 1, len, 1);
cur.v++;
if(ans < cur.v) {
ans = cur.v;
p[i] = cur.id;
}
}
if(p2 <= len) {
node cur = query(p2, len, 1, len, 1);
cur.v++;
if(ans < cur.v) {
ans = cur.v;
p[i] = cur.id;
}
}
update(pos, pos, i, ans, 1, len, 1);
res = max(res, ans);
}
printf("%d\n", res);
for(int i = n; i >= 1; i--) {
if(dp[i] == res) {
path(i);
break;
}
}
return 0;
}