树剖加线段树好题
/**********
明天你是否会想起
昨天未调完的题
明天你是否还惦记
考场写挂的暴力
**********/
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod=1e18;//按题目要求随时修改.
void read(int &x)
{
int f=1;x=0;char s=getchar();
while(s<'0'||s>'9'){if(s=='-') f=-1;s=getchar();}
while(s>='0'&&s<='9'){x=x*10%mod+(s-'0')%mod;s=getchar();}
x=x%mod*f;
}
const int N=4e5+10;
struct xds
{
int l,r;
// int lazy;
int sum;
int lazy;
}xd[4*N];
vector<int>f[N];
int n,m;
int h[N];
int sum[N],fa[N],dep[N],son[N],top[N];
int new_point[N],pre[N];
int cnt;
void dfs(int father,int cur)
{
sum[cur]=1;
for(auto v:f[cur])
{
if(v==father) continue;
fa[v]=cur;
dep[v]=dep[cur]+1;
dfs(cur,v);
sum[cur]+=sum[v];
if(sum[v]>sum[son[cur]]) son[cur]=v;
}
}
void dfs2(int fcur,int cur)
{
top[cur]=fcur;
cnt++;
new_point[cur]=cnt;
pre[cnt]=cur;
if(son[cur]) dfs2(fcur,son[cur]);
for(auto v:f[cur])
{
if(v==fa[cur]||v==son[cur]) continue;
dfs2(v,v);
}
}
void pushup(int u)
{
xd[u].sum=xd[u<<1].sum|xd[u<<1|1].sum;
}
void build(int u,int l,int r)
{
if(l==r) xd[u]={l,r,1ll<<h[pre[l]],0};
else
{
xd[u]={l,r,0,0};
int mid=l+r>>1;
build(u<<1,l,mid),build(u<<1|1,mid+1,r);
pushup(u);
}
}
void pushdown(int u)
{
if(xd[u].lazy)
{
int c=xd[u].lazy;
xd[u<<1].sum=xd[u<<1|1].sum=1ll<<c;
xd[u<<1].lazy=xd[u<<1|1].lazy=c;
xd[u].lazy=0;
}
}
void motify(int u,int l,int r,int c)
{
if(xd[u].l>=l&&xd[u].r<=r)
{
xd[u].sum=1ll<<c;
xd[u].lazy=c;
}
else
{
pushdown(u);
int mid=xd[u].l+xd[u].r>>1;
if(l<=mid) motify(u<<1,l,r,c);
if(r>mid) motify(u<<1|1,l,r,c);
pushup(u);
}
}
int query(int u,int l,int r)
{
if(xd[u].l>=l&&xd[u].r<=r) return xd[u].sum;
else
{
pushdown(u);
int res=0;
int mid=xd[u].l+xd[u].r>>1;
if(l<=mid) res=query(u<<1,l,r);
if(r>mid) res|=query(u<<1|1,l,r);
return res;
}
}
void solve()
{
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>h[i];
for(int i=1;i<=n-1;i++)
{
int u, v;
cin >> u >> v;
f[u].push_back(v);
f[v].push_back(u);
}
dfs(0,1);
dfs2(1,1);
build(1,1,n);
while(m--)
{
int op,u,color;
cin>>op;
if(op==1)
{
cin>>u>>color;
motify(1,new_point[u],new_point[u]+sum[u]-1,color);
}
else
{
cin>>u;
int ans=query(1,new_point[u],new_point[u]+sum[u]-1);
int k=0;
while(ans)
{
if(ans&1) k++;
ans>>=1;
}
cout<<k<<'\n';
}
}
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr),cout.tie(nullptr);
int t;
t=1;
while(t--){solve();}
return 0;
}