题意
给你一个n*k的非负整数网格P以及一个系数数组A。
要求每行选一个格子,每选一个第j列的格子,这种选择方案的权值就乘上
A
[
j
]
A[j]
A[j]。
问,对于
[
0
,
2
m
)
[0,2^m)
[0,2m)中所有数x,选的所有格子的P的异或=x的权值和是多少。
n<=1e6
分析
据说是CF H题的加强。
首先有简单FWT做法。即求所有行的多项式
F
i
F_i
Fi,然后
A
n
s
=
∏
F
i
Ans=\prod F_i
Ans=∏Fi,用FWT加速。
O
(
n
m
2
m
)
O(nm2^m)
O(nm2m)
接下来优化该做法:
首先将P中所有元素异或上该行第0个(从0标号),最终答案再对应换位。
注意到每个多项式F的系数都是若干个
a
i
a_i
ai之和,并且
F
W
T
(
A
)
[
i
]
=
∑
j
(
−
1
)
b
i
t
c
o
u
n
t
(
i
&
j
)
A
j
FWT(A)[i]=\sum_j (-1)^{bitcount(i\&j)}A_j
FWT(A)[i]=j∑(−1)bitcount(i&j)Aj
考虑计算
F
W
T
(
A
n
s
)
FWT(Ans)
FWT(Ans)中的每一个位置的每种和(每一个A_i的系数是正或负1)的个数。
最终这些和的乘积即为
F
W
T
(
A
n
s
)
FWT(Ans)
FWT(Ans)
若将+看作0,-看作1,则符号乘法是与异或等价的。我们用一个数w来代表一种和。注意到
p
[
i
]
[
0
]
p[i][0]
p[i][0]已经被清零,可以知道a0在所有和的系数中都是1.因此w只需要
k
−
1
k-1
k−1个二进制位来表示即可。
然后,用
X
w
X_w
Xw来表示,在
F
W
T
(
A
n
s
)
FWT(Ans)
FWT(Ans)的某个特定位置中,
w
w
w这种和有多少个。
现求
F
W
T
(
X
)
FWT(X)
FWT(X)
这一步非常巧妙:枚举一个列
[
1
,
k
)
[1,k)
[1,k)中的子集S,对于一行,使一个空多项式G的第e个位置+1. 其中e是这一行中所有在S中的列的P的异或。然后,令
H
=
F
W
T
(
G
)
H=FWT(G)
H=FWT(G)。考虑这样的H有什么意义。
假如S中只有一个元素t,那么
H
[
i
]
H[i]
H[i]就是这一行
F
W
T
(
F
)
FWT(F)
FWT(F)在第i个位置上,给的是
+
a
[
t
]
+a[t]
+a[t]还是
−
a
[
t
]
-a[t]
−a[t]。不难证明,假如有多个元素,那么
H
[
i
]
H[i]
H[i]是单个t做出来的H的乘积。
再考虑上异或与乘法的等价关系,一种(在当前行累加到答案的第i个位置的)和
w
w
w在
H
[
i
]
H[i]
H[i]中的贡献,就是
(
−
1
)
∣
w
&
S
∣
(-1)^{|w\&S|}
(−1)∣w&S∣。
因为
∑
F
W
T
(
G
)
=
F
W
T
(
∑
G
)
\sum FWT(G)=FWT(\sum G)
∑FWT(G)=FWT(∑G),因此可以O(n)+一次FWT求出
D
=
∑
每
一
行
H
D=\sum_{每一行} H
D=∑每一行H。
那么就有,
D
[
i
]
=
∑
w
(
−
1
)
∣
w
&
S
∣
X
w
D[i]=\sum_w (-1)^{|w\&S|} X_w
D[i]=∑w(−1)∣w&S∣Xw。右侧正好是
F
W
T
(
X
)
[
S
]
FWT(X)[S]
FWT(X)[S]。
这™是仙术吧
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mo = 998244353, N = 1e6 + 10, M = 1 << 20;
int n, m, k;
ll a[N];
ll p[N][20];
void read(ll &x) {
char c; while(((c=getchar()))<'0'||c>'9');
x = c - '0';
while ((c=getchar()) >= '0' && c <= '9') x = x * 10 + c - '0';
}
ll ksm(ll x, ll y) {
ll ret = 1; for (; y; y >>= 1) {
if (y & 1) ret = ret * x % mo;
x = x * x % mo;
}
return ret;
}
ll ans[M];
vector<ll> f[M];
ll g[M];
void FWT(ll *a, ll e) {
int M = 1 << e;
for(int m = 1; m < M; m <<= 1) {
for(int i = 0; i < M; i += (m << 1)) {
for(int j = 0; j < m; j++) {
ll t = a[i + j + m];
a[i + j + m] = (a[i + j] - t) % mo;
a[i + j] = (a[i + j] + t) % mo;
}
}
}
}
void IFWT(ll *a, ll e) {
FWT(a, e);
ll ny = ksm(1 << e, mo - 2);
for(int i = 0; i < (1 << e); i++) a[i] = a[i] * ny % mo;
}
int main() {
freopen("yuyuko.in","r",stdin);
// freopen("yuyuko.out","w",stdout);
cin >> n >> m >> k;
for(int i = 0; i < k; i++) read(a[i]);
int sf = 0;
for(int i = 1; i <= n; i++) {
for(int j = 0; j < k; j++)
read(p[i][j]);
sf ^= p[i][0];
for(int j = 1; j < k; j++) p[i][j] ^= p[i][0];
p[i][0] = 0;
}
for(int i = 0; i < (1 << m); i++) {
f[i].resize(1 << k - 1);
}
for(int s = 0; s < (1 << k - 1); s++) {
memset(g, 0, (1 << m) * sizeof g[0]);
for(int i = 1; i <= n; i++) {
int xs = 0;
for(int j = 1; j < k; j++) if (s & (1 << j - 1)) {
xs ^= p[i][j];
}
g[xs]++;
}
FWT(g, m);
for(int i = 0; i < (1 << m); i++) f[i][s] = g[i];
}
static int sum[M];
for(int i = 0; i < (1 << k - 1); i++) {
sum[i] = a[0];
for(int z = 1; z < k; z++) {
if (i & (1 << z - 1)) {
sum[i] = (sum[i] - a[z]) % mo;
} else sum[i] = (sum[i] + a[z]) % mo;
}
}
static ll tmp[M];
for(int i = 0; i < (1 << m); i++) {
for(int j = 0; j < (1 << k - 1); j++) tmp[j] = f[i][j];
IFWT(tmp, k - 1);
ans[i] = 1;
for(int j = 0; j < (1 << k - 1); j++) {
ans[i] = ans[i] * ksm(sum[j], (tmp[j] + mo) % mo) % mo;
}
}
IFWT(ans, m);
for(int i = 0; i < (1 << m); i++) {
if ((sf ^ i) < i) swap(ans[i], ans[sf ^ i]);
}
for(int i = 0; i < (1 << m); i++)
printf("%lld ", (ans[i] + mo) % mo);
}