写下自己出错过的地方:
1, 读入$n-1$条边, 别写成REP(i,1,n)
2, 分治时, solve(rt)别写成solve(y)
3,
练习1 CF 161D
大意:给定树, 求长为k的链的个数
该题算点分入门了, 比较简单, 可以拿来练手. 当然这题$k$的范围较小, $O(nk)$的暴力dp也是能过的
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <vector>
#define pb push_back
#define REP(i,a,n) for(int i=a;i<=n;++i)
using namespace std;
typedef long long ll;
const int N = 4e5+10, INF = 0x3f3f3f3f;
int n, k, sum, rt;
int maxp[N], sz[N], vis[N];
int cnt[N];
ll ans;
vector<int> g[N];
void getrt(int x, int fa) {
sz[x]=1, maxp[x]=0;
for (int y:g[x]) if (!vis[y]&&y!=fa) {
getrt(y,x),sz[x]+=sz[y];
maxp[x]=max(maxp[x],sz[y]);
}
maxp[x]=max(maxp[x],sum-sz[x]);
if (maxp[rt]>maxp[x]) rt=x;
}
void dfs(int x, int fa, int d, int v) {
if (d>k) return;
cnt[d] += v;
for (int y:g[x]) if (!vis[y]&&y!=fa) dfs(y,x,d+1,v);
}
void dfs2(int x, int fa, int d) {
if (d>k) return;
ans += cnt[k-d];
for (int y:g[x]) if (!vis[y]&&y!=fa) dfs2(y,x,d+1);
}
void solve(int x) {
vis[x]=1;
dfs(x,0,0,1);
ans += cnt[k];
for (int y:g[x]) if (!vis[y]) {
dfs(y,x,1,-1);
dfs2(y,x,1);
dfs(y,x,1,1);
}
dfs(x,0,0,-1);
for (int y:g[x]) if (!vis[y]) {
maxp[rt=0]=n,sum=sz[y];
getrt(y,0), solve(rt);
}
}
int main() {
scanf("%d%d", &n, &k);
REP(i,2,n) {
int u, v;
scanf("%d%d", &u, &v);
g[u].pb(v),g[v].pb(u);
}
sum=maxp[0]=n, getrt(1,0),solve(rt);
printf("%lld\n", ans>>1);
}
练习2 POJ 1741
大意: 求树上有多少长度不超过k的链
跟上题一样, 树状数组维护一个前缀即可
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <vector>
#define REP(i,a,n) for(int i=a;i<=n;++i)
#define pb push_back
using namespace std;
typedef long long ll;
const int N = 4e5+10, INF = 0x3f3f3f3f;
int n, k, rt, sum;
struct _ {int to,w;};
vector<_> g[N];
ll ans;
int sz[N], mx[N], vis[N];
int c[N];
void add(int x, int v) {
for (++x; x<=k+1; x+=x&-x) c[x]+=v;
}
int qry(int x) {
int r = 0;
for (++x; x; x^=x&-x) r+=c[x];
return r;
}
void getrt(int x, int fa) {
mx[x]=0,sz[x]=1;
for (_ e:g[x]) if (!vis[e.to]&&e.to!=fa) {
int y = e.to;
getrt(y,x),sz[x]+=sz[y];
mx[x]=max(mx[x],sz[y]);
}
mx[x]=max(mx[x],sum-sz[x]);
if (mx[x]<mx[rt]) rt=x;
}
void dfs(int x, int fa, int w, int v) {
if (w>k) return;
add(w,v);
for (_ e:g[x]) if (!vis[e.to]&&e.to!=fa) {
int y = e.to;
dfs(y,x,w+e.w,v);
}
}
void dfs2(int x, int fa, int w) {
if (w>k) return;
ans += qry(k-w);
for (_ e:g[x]) if (!vis[e.to]&&e.to!=fa) {
int y = e.to;
dfs2(y,x,w+e.w);
}
}
void solve(int x) {
vis[x]=1;
dfs(x,0,0,1);
ans += qry(k)-qry(0);
for (_ e:g[x]) if (!vis[e.to]) {
int y = e.to;
dfs(y,x,e.w,-1);
dfs2(y,x,e.w);
dfs(y,x,e.w,1);
}
dfs(x,0,0,-1);
for (_ e:g[x]) if (!vis[e.to]) {
int y = e.to;
mx[rt=0]=n,sum=sz[y];
getrt(y,0),solve(rt);
}
}
int main() {
for (; scanf("%d%d", &n, &k),n||k; ) {
REP(i,1,n) g[i].clear();
REP(i,2,n) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
g[u].pb({v,w});
g[v].pb({u,w});
}
sum=mx[0]=n,getrt(1,0),solve(rt);
printf("%lld\n", ans>>1);
}
}
练习3 CF 914E
大意: 给定树,每个点上有个字母a-t之一, 对于每个点, 求有多少经过该点的链使得: 链上字母的某一个排列是回文的
注意到回文等价于出现奇数次的字母最多只有一个, 状压一下即可
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <set>
#include <vector>
#define REP(i,a,n) for(int i=a;i<=n;++i)
#define pb push_back
using namespace std;
typedef long long ll;
const int N = 2e5+10, INF = 0x3f3f3f3f;
int n, sum, rt;
int a[N], vis[N], sz[N], maxp[N];
char s[N];
vector<int> g[N];
ll ans[N], cnt[1<<20];
void getrt(int x, int fa) {
sz[x]=1, maxp[x]=0;
for (int y:g[x]) if (!vis[y]&&y!=fa) {
getrt(y,x),sz[x]+=sz[y];
maxp[x]=max(maxp[x],sz[y]);
}
maxp[x]=max(maxp[x],sum-sz[x]);
if (maxp[rt]>maxp[x]) rt=x;
}
void dfs(int x, int fa, int s, int v) {
s ^= a[x], cnt[s] += v;
for (int y:g[x]) if (!vis[y]&&y!=fa) dfs(y,x,s,v);
}
ll calc(int x, int fa, int s) {
s ^= a[x];
ll r = cnt[s];
REP(i,0,19) r += cnt[s^1<<i];
for (int y:g[x]) if (!vis[y]&&y!=fa) r+=calc(y,x,s);
ans[x] += r;
return r;
}
void solve(int x) {
vis[x] = 1;
dfs(x,0,0,1);
ll r = cnt[0];
REP(i,0,19) r += cnt[1<<i];
for (int y:g[x]) if (!vis[y]) {
dfs(y,x,a[x],-1);
r += calc(y,x,0);
dfs(y,x,a[x],1);
}
//单独一个字母s[x]的链只计算了一次
//其余链均正反记录两次
ans[x] += (r+1)/2;
dfs(x,0,0,-1);
for (int y:g[x]) if (!vis[y]) {
maxp[rt=0]=n,sum=sz[y];
getrt(y,0), solve(rt);
}
}
int main() {
scanf("%d", &n);
REP(i,2,n) {
int u, v;
scanf("%d%d", &u, &v);
g[u].pb(v),g[v].pb(u);
}
scanf("%s", s+1);
REP(i,1,n) a[i]=1<<(s[i]-'a');
sum=maxp[rt]=n, getrt(1,0), solve(rt);
REP(i,1,n) printf("%lld ", ans[i]);
puts("");
}