花了快1个小时,o(╯□╰)o 好弱。。。
定义T序列:为一段连续的序列且序列中的最大元素-最小元素<=1。
题意:给定一个由n个元素组成的序列a[],保证相邻元素之间差的绝对值不超过1。问你最长的T序列。
思路:定义dp[i]为以a[i]结尾的最长T序列。用一个结构体存储以下信息
le——序列长度,mx——序列最大元素,mi——序列最小元素。
考虑a[i]的状态,根据a[i] 与 dp[i-1].mi 和 dp[i-1].mx的关系,得到状态转移方程
一、放入元素a[i]依旧是T序列,限制条件dp[i-1].mi <= a[i] <= dp[i-1].mx || dp[i-1].mi == dp[i-1].mx
dp[i].le = dp[i-1].le + 1;
dp[i].mi = min(a[i], dp[i-1].mi);dp[i].mx = max(a[i], dp[i-1].mx);
二、放入元素a[i]不再是T序列
设v1 = min(a[i], a[j]) - 1 和 v2 = max(a[i], a[j]) + 1 即与a[i]矛盾的数(差的绝对值大于1)
我们的目的是找到离a[i]最近的v1的位置p1和离a[i]最近的v2的位置p2。
这样就会有
dp[i].le = i - max(p1, p2);
dp[i].mi = min(dp[i].mi, a[j]);
dp[i].mx = max(dp[i].mx, a[j]);
关键在于快速求出p1和p2。我的想法是先用num结构体记录元素val以及id值,排序后。在num.val里面找到最小 的k使得a[k] = v1,这样num[k + (中间v1元素的个数) - 1].id就是p1,同理求出p2。
至于求元素个数,用map就好了,边做dp边累加。 这样时间复杂度为O(nlogn)。
AC代码:
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <stack>
#include <map>
#include <vector>
#define INF 0x3f3f3f
#define eps 1e-8
#define MAXN (1000000+10)
#define MAXM (100000)
#define Ri(a) scanf("%d", &a)
#define Rl(a) scanf("%lld", &a)
#define Rf(a) scanf("%lf", &a)
#define Rs(a) scanf("%s", a)
#define Pi(a) printf("%d\n", (a))
#define Pf(a) printf("%.2lf\n", (a))
#define Pl(a) printf("%lld\n", (a))
#define Ps(a) printf("%s\n", (a))
#define W(a) while(a--)
#define CLR(a, b) memset(a, (b), sizeof(a))
#define MOD 1000000007
#define LL long long
#define lson o<<1, l, mid
#define rson o<<1|1, mid+1, r
#define ll o<<1
#define rr o<<1|1
using namespace std;
int a[MAXN], cnt[MAXN];
struct Node{
int mx, mi, le;
} ;
Node dp[MAXN];
struct Rec{
int val, id;
};
Rec num[MAXN];
bool cmp(Rec a, Rec b)
{
if(a.val != b.val)
return a.val < b.val;
else
return a.id < b.id;
}
int Find(int l, int r, int v)
{
int ans;
while(r >= l)
{
int mid = (l + r) >> 1;
if(num[mid].val >= v)
{
ans = mid;
r = mid-1;
}
else if(num[mid].val < v)
l = mid+1;
}
return ans;
}
map<int, int> fp;
int main()
{
int n; Ri(n);
for(int i = 1; i <= n; i++)
{
Ri(a[i]);
num[i].val = a[i];
num[i].id = i;
}
sort(num+1, num+n+1, cmp);
int ans = 1;
dp[1].le = 1; dp[1].mx = dp[1].mi = a[1];
fp.clear(); fp[a[1]]++;
for(int i = 2; i <= n; i++)
{
dp[i].mx = dp[i].mi = a[i];
fp[a[i]]++;
int j = i - 1;
if(dp[j].mi <= a[i] && a[i] <= dp[j].mx || dp[j].mx == dp[j].mi)
{
dp[i].le = dp[j].le + 1;
dp[i].mi = min(dp[i].mi, dp[j].mi);
dp[i].mx = max(dp[i].mx, dp[j].mx);
}
else
{
int v1 = min(a[i], a[j]) - 1;
int v2 = max(a[i], a[j]) + 1;
int p1, p2;
if(!fp[v1])
p1 = 0;
else
p1 = Find(1, n, v1) + fp[v1] - 1;
if(!fp[v2])
p2 = 0;
else
p2 = Find(1, n, v2) + fp[v2] - 1;
dp[i].le = i - max(num[p1].id, num[p2].id);
dp[i].mi = min(dp[i].mi, a[j]);
dp[i].mx = max(dp[i].mx, a[j]);
}
ans = max(ans, dp[i].le);
}
Pi(ans);
return 0;
}