引入
求
d
k
=
∑
i
o
p
t
j
=
k
a
i
b
j
d_k = \sum_{i \ opt\ j=k}a_ib_j
dk=i opt j=k∑aibj
其中
o
p
t
opt
opt是某种位运算。不准暴力
分析
我们已经知道可以用FFT来加速一个多项式卷积的运算。即将一个多项式转化为点值表示法,通过增加转换部分的用时来减少计算的用时,从而优化整个运算过程。
(图不是我画的)
于是我们想利用类似的思路,将数列 A A A, B B B通过某种方式转化成另一种形式,从而加快计算的速度。
FWT与IFWT
按照上面的思路,我们想将数列
A
A
A,
B
B
B转化为数列
A
′
A'
A′,
B
′
B'
B′,然后算得
d
k
′
=
∑
i
=
1
n
a
k
′
b
k
′
d'_k=\sum_{i=1}^na'_kb'_k
dk′=i=1∑nak′bk′
再将
D
′
D'
D′转化为
D
D
D即可。
观察我们想要计算的结果可知,变换
F
W
T
(
A
)
=
A
′
FWT(A)=A'
FWT(A)=A′
的过程一定得是个关于
A
A
A中各个元素的线性变换(如果有交叉项的话必定得不到
C
C
C中的元素)。也就是说,上述的变换本质上就是一个矩阵
C
C
C。再结合我们的需求,我们所要做的就是找到一个可逆矩阵
C
C
C。
那么这个矩阵有什么特性呢?令
F
W
T
(
A
)
⋅
F
W
T
(
B
)
=
F
W
T
(
D
)
FWT(A) \cdot FWT(B)=FWT(D)
FWT(A)⋅FWT(B)=FWT(D)
即需要满足对任意的
i
i
i
F
W
T
(
A
)
i
⋅
F
W
T
(
B
)
i
=
F
W
T
(
D
)
i
FWT(A)_i \cdot FWT(B)_i = FWT(D)_i
FWT(A)i⋅FWT(B)i=FWT(D)i
也就是
∑
k
=
1
n
∑
j
=
1
n
c
i
j
c
i
k
a
k
b
j
=
∑
k
=
1
n
c
i
k
d
k
\sum_{k=1}^n\sum_{j=1}^nc_{ij}c_{ik}a_kb_j=\sum_{k=1}^nc_{ik}d_k
k=1∑nj=1∑ncijcikakbj=k=1∑ncikdk
结合题意
∑
k
=
1
n
∑
j
=
1
n
c
i
j
c
i
k
a
k
b
j
=
∑
k
=
1
n
c
i
k
∑
m
o
p
t
l
=
k
a
m
b
l
\sum_{k=1}^n\sum_{j=1}^nc_{ij}c_{ik}a_kb_j=\sum_{k=1}^nc_{ik}\sum_{m \ opt\ l=k}a_mb_l
k=1∑nj=1∑ncijcikakbj=k=1∑ncikm opt l=k∑ambl
右式改为枚举
m
,
l
m,l
m,l
∑
k
=
1
n
∑
j
=
1
n
c
i
j
c
i
k
a
j
b
j
=
∑
k
=
1
n
∑
j
=
1
n
c
i
j
o
p
t
k
a
k
b
j
\sum_{k=1}^n\sum_{j=1}^nc_{ij}c_{ik}a_jb_j=\sum_{k=1}^n\sum_{j=1}^nc_{i\ j\ opt\ k}a_kb_j
k=1∑nj=1∑ncijcikajbj=k=1∑nj=1∑nci j opt kakbj
为使上式恒成立,可以规定
c
i
j
c
i
k
=
c
i
j
o
p
t
k
c_{ij}c_{ik}=c_{i\ \ j \ opt \ k}
cijcik=ci j opt k
于是我们就可以根据这个式子,按照实际的
o
p
t
opt
opt来构造这个矩阵了。
下面考虑有了矩阵
C
C
C以后如何计算
F
W
T
(
A
)
i
=
∑
k
=
1
n
c
i
k
a
k
FWT(A)_i=\sum_{k=1}^nc_{ik}a_k
FWT(A)i=k=1∑ncikak
为方便叙述,令
n
=
2
m
n=2^m
n=2m。我们将
n
n
n拆成最高位为0与最高位为1两个部分。
F
W
T
(
A
)
i
=
∑
k
=
1
n
2
−
1
c
i
k
a
k
+
∑
k
=
n
2
n
c
i
k
a
k
FWT(A)_i=\sum_{k=1}^{\frac{n}{2}-1}c_{ik}a_k+\sum_{k=\frac{n}{2}}^{n}c_{ik}a_k
FWT(A)i=k=1∑2n−1cikak+k=2n∑ncikak
发现这个形式跟归并排序有点像,于是就想着分治,也就需要将两项进行变形为相似的形式。
我们对这个矩阵进一步的把玩发现:若令
x
i
x_i
xi为
x
x
x的第
i
i
i位,则
c
i
j
=
∏
k
=
1
m
c
i
k
j
k
c_{ij}=\prod_{k=1}^mc_{i_kj_k}
cij=k=1∏mcikjk
所以我们可以把两项的矩阵系数拆出
k
k
k最高位的部分。令剩下的为
k
′
k'
k′,则
F
W
T
(
A
)
i
=
c
i
1
0
∑
k
=
1
n
2
−
1
c
i
′
k
′
a
k
+
c
i
1
1
∑
k
=
n
2
n
c
i
′
k
′
a
k
FWT(A)_i=c_{i_10}\sum_{k=1}^{\frac{n}{2}-1}c_{i'k'}a_k+c_{i_11}\sum_{k=\frac{n}{2}}^{n}c_{i'k'}a_k
FWT(A)i=ci10k=1∑2n−1ci′k′ak+ci11k=2n∑nci′k′ak
那么可以发现:我们无需构造原来的大小为
n
×
n
n × n
n×n的矩阵,而仅需构造一个
2
×
2
2×2
2×2的矩阵,就可以递归处理了
F
W
T
(
A
)
i
=
c
00
F
W
T
(
A
l
f
t
)
+
c
01
F
W
T
(
A
r
g
t
)
i
<
n
2
F
W
T
(
A
)
i
+
n
2
=
c
10
F
W
T
(
A
l
f
t
)
+
c
11
F
W
T
(
A
r
g
t
)
i
≥
n
2
FWT(A)_i=c_{00}FWT(A_{lft})+c_{01}FWT(A_{rgt}) \ \ \ i < \frac{n}{2} \\ FWT(A)_{i+\frac{n}{2}}=c_{10}FWT(A_{lft})+c_{11}FWT(A_{rgt}) \ \ \ i \geq \frac{n}{2}
FWT(A)i=c00FWT(Alft)+c01FWT(Argt) i<2nFWT(A)i+2n=c10FWT(Alft)+c11FWT(Argt) i≥2n
同理,由于构造出来的矩阵可逆,所以FWT的逆变换可以直接利用
C
C
C的逆矩阵做一次FWT即可。
基础位运算对应转移矩阵
与
C = ( 1 1 0 1 ) C=\left(\begin{matrix} 1 & 1\\ 0 & 1 \end{matrix}\right) C=(1011)
或
C = ( 1 1 0 1 ) C=\left(\begin{matrix} 1 & 1\\ 0 & 1 \end{matrix}\right) C=(1011)
异或
C = ( 1 1 1 − 1 ) C=\left(\begin{matrix} 1 & 1\\ 1 & -1 \end{matrix}\right) C=(111−1)
例题与代码
题面
考虑DP。设
f
[
i
]
[
j
]
f[i][j]
f[i][j]表示以
i
i
i为根的子树中异或和为
j
j
j的数目,则
f
[
i
]
[
j
⊕
k
]
=
∑
j
=
0
m
∑
k
=
0
m
f
[
i
]
[
j
]
⋅
f
[
s
o
n
[
i
]
]
[
k
]
f[i][j \oplus k]=\sum_{j=0}^m\sum_{k=0}^mf[i][j]\cdot f[son[i]][k]
f[i][j⊕k]=j=0∑mk=0∑mf[i][j]⋅f[son[i]][k]
利用FWT优化即可
#include<bits/stdc++.h>
#define reg register
#define ll long long
using namespace std;
const int mn = 1005, mod = 1e9+7;
const int inv2 = 500000004;
vector<int> g[mn];
ll f[mn][1030], h[1030];
ll cxor[2][2] = {{1,1},{1,mod-1}}, icxor[2][2] = {{inv2,inv2},{inv2,mod-inv2}};
int a[mn], n, m;
inline void fwt(ll *a, ll c[2][2])
{
for(reg int len = 1; len < m; len <<= 1)
for(reg int st = 0; st < m; st += len + len)
for(reg int i = st; i < st + len; i++)
{
ll tmp = a[i];
a[i] = (c[0][0] * a[i] + c[0][1] * a[i + len]) % mod;
a[i + len] = (c[1][0] * tmp + c[1][1] * a[i + len]) % mod;
}
}
void dp(int s, int fa)
{
int siz = g[s].size();
f[s][a[s]] = 1;
for(reg int i = 0; i < siz; i++)
{
int t = g[s][i];
if(t != fa)
{
dp(t, s);
fwt(f[s], cxor), fwt(f[t], cxor);
for(int j = 0; j < m; j++)
h[j] = f[s][j] * f[t][j] % mod;
fwt(h, icxor), fwt(f[t], icxor), fwt(f[s], icxor);
for(int j = 0; j < m; j++)
f[s][j] += h[j], f[s][j] %= mod;
}
}
}
inline int getint()
{
int ret = 0; char c = getchar();
while(c < '0' || c > '9') c = getchar();
while(c >= '0' && c <= '9') ret = ret * 10 + c - '0', c = getchar();
return ret;
}
int main()
{
int T, x, y;
T = getint();
while(T--)
{
n = getint(), m = getint();
for(reg int i = 1; i <= n; i++)
a[i] = getint(), g[i].clear();
for(reg int i = 1; i < n; i++)
{
x = getint(), y = getint();
g[x].push_back(y), g[y].push_back(x);
}
for(reg int i = 1; i <= n; i++)
for(reg int j = 0; j < m; j++)
f[i][j] = 0;
dp(1, 0);
for(reg int i = 0; i < m; i++)
{
ll ans = 0;
for(reg int j = 1; j <= n; j++)
ans += f[j][i], ans %= mod;
printf("%I64d", ans);
if(i != m - 1) putchar(' ');
else puts("");
}
}
}