Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),
如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Solution
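树链剖分模板题。
用线段树维护 DFS 序上的颜色：每个区间记录最左端颜色、最右端颜色和颜色段数；合并两个相邻区间时，若左区间的右端颜色等于右区间的左端颜色，则两段在接缝处连成一段，段数减 1。
查询路径时，两个端点分别沿重链向上跳并各自累积结果（方向均为链顶到链底）；最后要把其中一侧的结果左右翻转后再与另一侧合并，否则接缝处比较的颜色会错位。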
Code
/************************************************
* Au: Hany01
* Date: Sep 5th, 2018
* Prob: BZOJ2243 SDOI2011 染色
* Email: hany01dxx@gmail.com & hany01@foxmail.com
* Inst: Yali High School
************************************************/
#include<bits/stdc++.h>
using namespace std;
// Short aliases for common types.
typedef long long LL;
typedef long double LD;
typedef pair<int, int> PII;
// Loop helpers: rep runs i = 0 .. j-1, For runs i = j .. k, Fordown runs i = j down to k.
// Each bound is evaluated once into i##_end_. The `register` specifier that used to
// appear here was removed: it has no effect and is ill-formed since C++17
// (clang rejects it by default in -std=c++17 mode).
#define rep(i, j) for (int i = 0, i##_end_ = (j); i < i##_end_; ++ i)
#define For(i, j, k) for (int i = (j), i##_end_ = (k); i <= i##_end_; ++ i)
#define Fordown(i, j, k) for (int i = (j), i##_end_ = (k); i >= i##_end_; -- i)
// Byte-wise fill / copy of a whole object; `a` must be an array, not a pointer,
// because the size comes from sizeof(a).
#define Set(a, b) memset(a, b, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
// Pair member shorthands. These are object-like macros: every later identifier
// spelled x or y in this file is rewritten to first / second.
#define x first
#define y second
#define pb(a) push_back(a)
#define mp(a, b) make_pair(a, b)
#define SZ(a) ((int)(a).size())
#define ALL(a) a.begin(), a.end()
// 0x3f3f3f3f: every byte is 0x3f, so memset(arr, 0x3f, ...) fills ints with INF.
#define INF (0x3f3f3f3f)
// 2139062143 == 0x7f7f7f7f, the int value produced by memset(arr, 0x7f, ...).
#define INF1 (2139062143)
#define debug(...) fprintf(stderr, __VA_ARGS__)
// Renames y1 to a dummy identifier; presumably to avoid clashing with the Bessel
// function ::y1 that <cmath> may declare — NOTE(review): confirm intent.
#define y1 wozenmezhemecaia
// chkmax: when a < b, overwrites a with b and reports true; otherwise leaves a untouched and reports false.
template <typename T> inline bool chkmax(T &a, T b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}
// chkmin: when b < a, overwrites a with b and reports true; otherwise leaves a untouched and reports false.
template <typename T> inline bool chkmin(T &a, T b) {
    if (!(b < a)) return false;
    a = b;
    return true;
}
// Reads the next decimal integer from stdin via getchar().
// Any non-digit characters before the first digit are skipped; if a '-' is seen
// among them, the result is negated. The character right after the last digit is
// consumed. Returns 0 if EOF is reached before any digit. The previous version
// kept the character in a `char`, so EOF never compared as a digit and the skip
// loop spun forever.
inline int read() {
    int ch = getchar();  // int, not char, so EOF stays distinguishable
    int sign = 1;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    int value = 0;
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
const int maxn = 1e5 + 5;
// Number of tree nodes and number of operations.
int n, m;
// Adjacency list in forward-star form. beg[u] is the last edge added for u,
// nex[i] is the next edge of the same node, and v[i] is edge i's target.
// Edges are numbered from 1 (0 terminates a list) and e counts edges added so far.
// Sized for two directed edges per undirected tree edge.
int beg[maxn], v[maxn << 1], nex[maxn << 1], e;
// Heavy-light decomposition data: depth, parent, chain top, heavy child, subtree size.
int dep[maxn], fa[maxn], top[maxn], son[maxn], sz[maxn];
// dfn[u]: DFS timestamp of u; id[p]: node whose timestamp is p; clk: last timestamp issued.
int dfn[maxn], clk, id[maxn];
// Initial colour of each node.
int co[maxn];

// Prepends the directed edge uu -> vv to uu's adjacency list.
inline void add(int uu, int vv) {
    ++e;
    v[e] = vv;
    nex[e] = beg[uu];
    beg[uu] = e;
}
void DFS1(int u) {
dep[u] = dep[fa[u]] + 1, sz[u] = 1;
for (register int i = beg[u]; i; i = nex[i])
if (v[i] != fa[u]) {
fa[v[i]] = u, DFS1(v[i]), sz[u] += sz[v[i]];
if (sz[son[u]] < sz[v[i]]) son[u] = v[i];
}
}
void DFS2(int u) {
id[dfn[u] = ++ clk] = u;
if (son[u]) top[son[u]] = top[u], DFS2(son[u]);
for (register int i = beg[u]; i; i = nex[i])
if (v[i] != fa[u] && v[i] != son[u]) top[v[i]] = v[i], DFS2(v[i]);
}
// Colour summary of a contiguous run of positions:
//   l   - colour of the leftmost position,
//   r   - colour of the rightmost position,
//   sum - number of maximal same-colour segments in the run.
// A Node with sum == 0 stands for the empty run.
struct Node { int l, r, sum; };
// Concatenates run A (left) with run B (right). If the colours meeting at the
// seam are equal, the two boundary segments fuse and the count drops by one.
// An empty run (sum == 0) is the identity. Emptiness is tested via sum, not via
// l == 0, because 0 is an ordinary colour that read() can deliver. Testing l
// silently dropped real runs whose left colour was 0.
inline Node merge(Node A, Node B) {
    if (B.sum == 0) return A;
    if (A.sum == 0) return B;
    return Node{A.l, B.r, A.sum + B.sum - (A.r == B.l)};
}
/*
 * Segment tree over DFS positions 1..n; leaf p holds the colour of node id[p].
 *   tr[t]   - colour summary (Node) of node t's interval,
 *   setv[t] - pending "assign this colour to the whole interval" tag; -1 = none.
 * NOTE(review): -1 doubles as the "no tag" sentinel, so this assumes no input
 * colour is ever -1 — confirm against the problem constraints.
 * The child-index helpers and the midpoint are functions and locals rather than
 * #defines, so they no longer leak into, and rewrite identifiers in, the rest of
 * the file. Range bounds are named ql/qr because x and y are file-wide macros.
 */
struct SegTree {
    int setv[maxn << 2];
    Node tr[maxn << 2];

    static int leftChild(int t) { return t << 1; }
    static int rightChild(int t) { return (t << 1) | 1; }

    // Colours node t's whole interval with c: a single segment, remembered as a lazy tag.
    void paint(int t, int c) {
        tr[t].l = tr[t].r = setv[t] = c;
        tr[t].sum = 1;
    }

    // Forwards node t's pending tag (if any) to both children, then clears it.
    inline void pushdown(int t) {
        if (setv[t] == -1) return;
        paint(leftChild(t), setv[t]);
        paint(rightChild(t), setv[t]);
        setv[t] = -1;
    }

    // Builds node t over [l, r] from the initial colours co[id[.]] and clears all tags.
    void build(int t, int l, int r) {
        setv[t] = -1;
        if (l == r) {
            tr[t] = Node{co[id[l]], co[id[l]], 1};
            return;
        }
        int mid = (l + r) >> 1;
        build(leftChild(t), l, mid);
        build(rightChild(t), mid + 1, r);
        tr[t] = merge(tr[leftChild(t)], tr[rightChild(t)]);
    }

    // Assigns colour c to every position in [ql, qr]; node t covers [l, r].
    void update(int t, int l, int r, int ql, int qr, int c) {
        if (ql <= l && r <= qr) {
            paint(t, c);
            return;
        }
        pushdown(t);
        int mid = (l + r) >> 1;
        if (ql <= mid) update(leftChild(t), l, mid, ql, qr, c);
        if (qr > mid) update(rightChild(t), mid + 1, r, ql, qr, c);
        tr[t] = merge(tr[leftChild(t)], tr[rightChild(t)]);
    }

    // Returns the colour summary of positions [ql, qr] inside node t's interval [l, r].
    Node query(int t, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr) return tr[t];
        pushdown(t);
        int mid = (l + r) >> 1;
        if (qr <= mid) return query(leftChild(t), l, mid, ql, qr);
        if (ql > mid) return query(rightChild(t), mid + 1, r, ql, qr);
        return merge(query(leftChild(t), l, mid, ql, qr),
                     query(rightChild(t), mid + 1, r, ql, qr));
    }
} ST;
inline void modify(int u, int v, int c) {
int fu = top[u], fv = top[v];
while (fu != fv) {
if (dep[fu] < dep[fv]) swap(fu, fv), swap(u, v);
ST.update(1, 1, n, dfn[top[u]], dfn[u], c), u = fa[top[u]], fu = top[u];
}
if (dep[u] > dep[v]) swap(u, v);
ST.update(1, 1, n, dfn[u], dfn[v], c);
}
inline int query(int u, int v) {
int fu = top[u], fv = top[v];
Node Ansu = (Node){0, 0, 0}, Ansv = (Node){0, 0, 0};
while (fu != fv) {
if (dep[fu] < dep[fv]) swap(fu, fv), swap(u, v), swap(Ansu, Ansv);
Ansu = merge(ST.query(1, 1, n, dfn[top[u]], dfn[u]), Ansu), u = fa[top[u]], fu = top[u];
}
if (dep[u] < dep[v]) swap(u, v), swap(Ansu, Ansv);
Ansu = merge(ST.query(1, 1, n, dfn[v], dfn[u]), Ansu), swap(Ansu.l, Ansu.r);
return merge(Ansu, Ansv).sum;
}
int main()
{
#ifdef hany01
freopen("bzoj2243.in", "r", stdin);
freopen("bzoj2243.out", "w", stdout);
#endif
static int u, v;
static char ty[3];
n = read(), m = read();
For(i, 1, n) co[i] = read();
For(i, 2, n) u = read(), v = read(), add(u, v), add(v, u);
DFS1(1), DFS2(top[1] = 1), ST.build(1, 1, n);
while (m --) {
scanf("%s", ty);
if (ty[0] == 'C') u = read(), v = read(), modify(u, v, read());
else u = read(), printf("%d\n", query(u, read()));
}
return 0;
}