题目链接
dfs与线段树的结合:
我们记录对于进入每个节点和回溯每个节点的时间戳:
void dfs(int cur)
{
cnt++;
l[cur] = cnt;
for (int i = he[cur]; i; i = ne[i])
{
int y = ver[i];
dfs(y);
}
r[cur] = cnt;
return;
}
可以看出:对于cur节点的子树,它的进入、离开的时间戳都在cur的进入、离开的时间戳的这个区间内。换句话说。我们可以理解为子节点的时间戳是cur节点的子区间。我们以此关系建立线段树。每次修改节点,就是对该节点进入、离开的时间戳之内的所有区间(子节点)进行修改。即区间修改。查询可以理解为对l[cur]的单点查询。
下面是ac代码:
#include <iostream>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include <algorithm>
#include <queue>
#include <cstdio>
#include <cstdlib>
#define ll long long
using namespace std;
const int N = 50005;
int he[N], ne[N], ver[N];
int l[N], r[N];
bool vis[N], di[N];
int tot, cnt;
int mx = 0;
struct Node
{
int l, r;
int sum;
int add;
}tr[N<<2];
void init()
{
tot = 1;
memset(he, 0, sizeof(he));
memset(di, 0, sizeof(di));
cnt = 0;
}
void build(int p, int l, int r)
{
mx = max(mx, p);
tr[p].l = l; tr[p].r = r;
tr[p].sum = -1;
tr[p].add = -1;
if (l == r) return;
int mid = (l + r)>> 1;
build(p<<1, l, mid);
build(p<<1|1, mid+1, r);
}
void add(int x, int y)
{
ver[++tot] = y;
ne[tot] = he[x];
he[x] = tot;
}
void dfs(int cur)
{
cnt++;
l[cur] = cnt;
for (int i = he[cur]; i; i = ne[i])
{
int y = ver[i];
dfs(y);
}
r[cur] = cnt;
return;
}
void spread(int p)
{
if (tr[p].add == -1) return;
int l = p<<1, r = p<<1|1;
tr[l].sum = tr[p].add;
tr[r].sum = tr[p].add;
tr[l].add = tr[r].add = tr[p].add;
tr[p].add = -1;
}
void change(int p, int l, int r, int v)
{
if (tr[p].l >= l && tr[p].r <= r)
{
tr[p].sum = v;
tr[p].add = v;
return;
}
spread(p);
int mid = (tr[p].l + tr[p].r) >> 1;
if (l <= mid) change(p<<1, l, r, v);
if(r > mid) change(p<<1|1, l, r, v);
}
int ask(int p, int k)
{
if (tr[p].l == tr[p].r) return tr[p].sum;
spread(p);
int mid = (tr[p].l + tr[p].r) >> 1;
if (k <= mid) return ask(p<<1, k);
else return ask(p<<1|1, k);
}
void print()
{
for (int i = 1; i <= mx; i++)
cout << tr[i].l << " " <<tr[i].r << " " << tr[i].sum <<" " << tr[i].add << endl;
cout << "--------------------" << endl;
}
int main()
{
int t;
cin >> t;
int t0 = 1;
while(t--)
{
init();
int n;
scanf("%d", &n);
build(1, 1, n);
// print();
for (int i = 1; i < n; i++)
{
int x, y;
scanf("%d%d", &y, &x);
add(x, y);
di[y] = 1;
}
for (int i = 1; i <= n; i++)
{
if (!di[i])
{
dfs(i);
break;
}
}
int q;
scanf("%d", &q);
printf("Case #%d:\n", t0++);
while(q--)
{
char op[5];
scanf("%s", op);
if(op[0] == 'C')
{
int x;
scanf("%d", &x);
printf("%d\n", ask(1, l[x]));
// print();
}
else
{
int x, y;
scanf("%d%d", &x, &y);
change(1, l[x], r[x], y);
// print();
}
}
}
return 0;
}