单调栈
常见模型:给定一个序列,求 序列当中的每一个元素左侧 离它最近,且比它小/大 的元素 或 右侧 离它最近,且比它小/大 的元素
举个例子,比如给定一个序列:3、4、2、7、5
我们要找到每一个元素左侧,且最近的比它小的数是什么,如果不存在返回 -1
第一个元素 3
左侧没有比它小的数,返回 -1
,第二个元素 4
左侧比它小的数且最近的数是 3
,因此返回 3
,后面的元素以此类推,如下图,第一行是原序列,第二行是返回的答案。
单调栈思考方式和双指针类似:先想想暴力做法是什么,之后挖掘出一些性质,将目光集中在比较少的状态中,从而起到将时间复杂度降低的效果。
暴力做法 时间复杂度为 O(n^2)
,当数据达到1e5
显然会超时
for(int i=0; i<n; ++i) //枚举原序列中的每一个数
for(int j=i-1; j>=0; --j) //从a[i]左侧第一个数a[i-1]开始往左枚举
if(a[i]>a[j]) //知道找到a[i]左侧第一个比a[i]小的数a[j]为止
{
printf("%d ", a[j]); //答案
break;
}
看看暴力做法中有什么可以挖掘的性质,
对于暴力做法,随着第一重循环 i
往右枚举,我们可以用一个栈存储 i
左侧所有元素,初始时栈为空,i
指针每往右移动一个位置就会往栈中新添一个元素,因此当枚举到 a[i]
时,其左侧所有元素:a[0]、a[1]、...、a[i-1]
都会被加入栈中,当我们要找答案的时候,我们从栈顶开始往下查找,找到第一个比 a[i]
小的数则 break
若要进行优化,我们显然应该将重点放在栈上,分析一下 栈中是否有些元素一定不会被作为答案输出,
举个例子,假设在栈中有 a[3] ≥ a[5]
,那么 a[3]
一定不会在后续中作为答案输出,
因为:a[5]
在 a[3]
的右侧,且 a[5] ≤ a[3]
,当我们找答案时,如果 a[3]
是目标值,则意味着 a[3] < a[i]
,又由于 a[5] ≤ a[3]
,因此 a[5]
也小于 a[i]
,我们显然要查找离 a[i]
最近且比 a[i]
小的元素,则 a[3]
一定不会被用作答案输出,那么我们一定可以换成 a[5]
作为更优的答案
单调栈核心思想:
对于更一般的情况,如果栈中存在这样的逆序关系:a[x] ≥ a[y] (x < y)
,那么 a[x]
可以从栈中删去,最终栈中剩下的序列一定是个严格单调序列了
如果我们想在这个严格单调的栈(设为 stk[]
)中查找答案,从栈顶 stk[tt]
开始查找,如果 stk[tt] ≥ a[i]
,a[i]
是在 stk[tt]
更右侧,显然答案更优,那么 栈顶 stk[tt]
必定不会在后续中被当成答案输出,因此可以直接删掉栈顶元素,这样循环进行,直到一个栈顶元素 stk[tt]
小于 a[i]
为止,此时的 栈顶元素 stk[tt]
就是 a[i]
左侧且离它最近且比它小的元素(答案),之后把 a[i]
压入栈中。
时间复杂度:
O(n)O(n)O(n)
代码片段:
for(int i=0; i<n; ++i)
{
int x;
scanf("%d", &x); //读入 n 个元素
//当栈不为空且栈顶元素 stk[tt] 大于等于当前数时栈顶元素一定不会被再用到,则删除栈顶元素 tt--
while(tt && stk[tt]>=x) --tt;
if(tt) printf("%d ", stk[tt]); //上述操作删除栈顶操作完成后,如果栈不为空,那么栈顶则为我们要找的答案,输出即可
else printf("-1 "); //如栈为空,说明x左侧没有任何一个元素比它小,输出 -1 即可
stk[++tt]=x; //最后记得将 x 压入栈中
}
例题:AcWing 830. 单调栈
代码:
手打数组模拟栈
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+10;
int stk[N], n, tt;
int main()
{
cin>>n;
for(int i=0; i<n; ++i)
{
int x;
scanf("%d", &x);
while(tt && stk[tt]>=x) --tt;
if(tt) printf("%d ", stk[tt]);
else printf("-1 ");
stk[++tt]=x;
}
return 0;
}
STL栈
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+10;
int n;
stack<int> stk;
int main()
{
cin>>n;
for(int i=0; i<n; ++i)
{
int x;
scanf("%d", &x);
while(stk.size() && stk.top()>=x) stk.pop();
if(stk.size()) printf("%d ", stk.top());
else printf("-1 ");
stk.push(x);
}
return 0;
}