A All with Pairs(哈希+kmp)
把每个串所有后缀的哈希值放进一个桶(哈希表)里计数,然后枚举每个串 $s$ 的前缀,就能快速知道这个前缀作为后缀一共出现了多少次。但是这个前缀不一定是最长的匹配,这个很好解决:记 $next_i$ 为 KMP 失配数组,即前缀 $[0,i]$ 最长真公共前后缀(border)的长度。一个前缀 $[0,i]$ 满足条件的时候,$[0,next_i-1]$ 一定也满足条件,所以只用在前缀 $[0,next_i-1]$ 处减掉前缀 $[0,i]$ 的贡献即可,这样每一对串只会被计入其最长的匹配前缀。
代码如下
#include <bits/stdc++.h>
#include<ext/pb_ds/hash_policy.hpp>
#include<ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
// 64-bit signed / unsigned shorthands used throughout this program.
using LL = long long;
using uLL = unsigned long long;
/// Hash functor for 64-bit keys. Each key is offset by a value taken once from
/// the steady clock and then mixed with SplitMix64 (presumably so that fixed
/// adversarial key sets cannot target the hash table's bucket layout).
struct custom_hash {
    /// SplitMix64 finalizer: a bijective 64-bit mixing function.
    static uint64_t splitmix64(uint64_t value) {
        value += 0x9e3779b97f4a7c15;
        value ^= value >> 30;
        value *= 0xbf58476d1ce4e5b9;
        value ^= value >> 27;
        value *= 0x94d049bb133111eb;
        return value ^ (value >> 31);
    }
    size_t operator()(uint64_t key) const {
        // Initialized on the first call only, then shared by every later call.
        static const uint64_t offset =
            std::chrono::steady_clock::now().time_since_epoch().count();
        return splitmix64(key + offset);
    }
};
// Multiplier used to promote int products to 64-bit before a modulus is taken.
// (Same type as the file's LL typedef, spelled out so this block is self-contained.)
long long z = 1;

/// Reads the next decimal integer from stdin.
/// Non-digit characters before the number are skipped; a '-' among them makes the
/// result negative.
/// @return the parsed value, or 0 if end of input is reached before any digit.
int read(){
    int x, f = 1;
    int ch;  // int, not char: a char cannot reliably hold EOF, which caused an endless loop
    while(ch = getchar(), ch < '0' || ch > '9'){
        if(ch == EOF) return 0;
        if(ch == '-') f = -1;
    }
    x = ch - '0';
    while(ch = getchar(), ch >= '0' && ch <= '9') x = x * 10 + ch - '0';
    return x * f;
}

/// Modular exponentiation by squaring.
/// @param a base; any int, reduced into [0, p) first
/// @param b exponent; expected >= 0 (a negative exponent is treated as 0)
/// @param p modulus; expected in [1, 2^31)
/// @return a^b mod p, in [0, p)
int ksm(int a, int b, int p){
    a %= p;
    if(a < 0) a += p;
    int s = 1 % p;  // 1 % p so that p == 1 yields 0
    while(b > 0){
        if(b & 1) s = z * s * a % p;
        a = z * a * a % p;
        b >>= 1;
    }
    return s;
}
// N: capacity of every per-string / per-position array; maxn: highest power of base
// precomputed into po; mod: modulus of the printed answer.
const int N = 1000005;
const int maxn = 1000000;
const int mod = 998244353;
// Input strings, 1-indexed.
string s[N];
// hx: running prefix hash; hx1: prefix hashes of the current string; po[k] = base^k.
uLL hx;
uLL hx1[N];
uLL po[N];
uLL base = 131;
// f[h]: number of suffixes (over all strings) whose hash is h.
gp_hash_table<uLL, int, custom_hash> f;
// nxt: KMP failure function of the current string; cnt: per-prefix match counts.
int nxt[N];
LL cnt[N];
// Hash of the 1-indexed substring [l, r], given prefix hashes hx (hx[k] covers the
// first k characters) and the power table po; arithmetic wraps modulo 2^64.
uLL gethash(uLL hx[], int l, int r){
    const uLL whole = hx[r];
    const uLL head = hx[l - 1];
    return whole - head * po[r - l + 1];
}
int main(){
int i, j, n, m, t;
for(po[0] = i = 1; i <= maxn; i++) po[i] = po[i - 1] * base;
n = read();
for(i = 1; i <= n; i++){
cin >> s[i]; m = s[i].length(); hx = 0;
for(j = 1; j <= m; j++) hx1[j] = hx1[j - 1] * base + s[i][j - 1];
for(j = m; j >= 1; j--) f[gethash(hx1, j, m)]++;
}
LL ans = 0;
for(i = 1; i <= n; i++){
m = s[i].length(); t = 0;
for(j = 0; j < m; j++) nxt[j] = cnt[j] = 0;
t = 0;
for(j = 1; j < m; j++){
while(t && s[i][j] != s[i][t]) t = nxt[t - 1];
if(s[i][j] == s[i][t]) nxt[j] = (++t);
}
hx = 0;
for(j = 0; j < m; j++){
hx = hx * base + s[i][j];
cnt[j] = f[hx];
if(nxt[j] > 0) cnt[nxt[j] - 1] -= cnt[j];
}
for(j = 0; j < m; j++) ans = (ans + z * cnt[j] * (j + 1) * (j + 1) % mod) % mod;
}
printf("%d", ans);
return 0;
}
E Exclusive OR(FWT)
由于 $A_i<2^{18}$,因此线性基的秩不超过 $18$。

因此当 $i\ge 20$ 时,$ans_i=ans_{i-2}$,我们只需要求前 $19$ 项的值。

记 $f_{t,i}$ 表示用 $t$ 个数时,异或值为 $i$ 的方案数。那么有一个简单的 dp:

$$f_{t,i}=\sum_{j\oplus k=i}f_{t-1,j}\times f_{1,k}$$

这是一个异或卷积,直接用 FWT(快速沃尔什变换)计算即可。
代码如下
#include <bits/stdc++.h>
#include<ext/pb_ds/hash_policy.hpp>
#include<ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
// 64-bit signed / unsigned shorthands used throughout this program.
using LL = long long;
using uLL = unsigned long long;
/// Hash functor for 64-bit keys. Each key is offset by a value taken once from
/// the steady clock and then mixed with SplitMix64 (presumably so that fixed
/// adversarial key sets cannot target the hash table's bucket layout).
struct custom_hash {
    /// SplitMix64 finalizer: a bijective 64-bit mixing function.
    static uint64_t splitmix64(uint64_t value) {
        value += 0x9e3779b97f4a7c15;
        value ^= value >> 30;
        value *= 0xbf58476d1ce4e5b9;
        value ^= value >> 27;
        value *= 0x94d049bb133111eb;
        return value ^ (value >> 31);
    }
    size_t operator()(uint64_t key) const {
        // Initialized on the first call only, then shared by every later call.
        static const uint64_t offset =
            std::chrono::steady_clock::now().time_since_epoch().count();
        return splitmix64(key + offset);
    }
};
// Multiplier used to promote int products to 64-bit before a modulus is taken.
// (Same type as the file's LL typedef, spelled out so this block is self-contained.)
long long z = 1;

/// Reads the next decimal integer from stdin.
/// Non-digit characters before the number are skipped; a '-' among them makes the
/// result negative.
/// @return the parsed value, or 0 if end of input is reached before any digit.
int read(){
    int x, f = 1;
    int ch;  // int, not char: a char cannot reliably hold EOF, which caused an endless loop
    while(ch = getchar(), ch < '0' || ch > '9'){
        if(ch == EOF) return 0;
        if(ch == '-') f = -1;
    }
    x = ch - '0';
    while(ch = getchar(), ch >= '0' && ch <= '9') x = x * 10 + ch - '0';
    return x * f;
}

/// Modular exponentiation by squaring.
/// @param a base; any int, reduced into [0, p) first
/// @param b exponent; expected >= 0 (a negative exponent is treated as 0)
/// @param p modulus; expected in [1, 2^31)
/// @return a^b mod p, in [0, p)
int ksm(int a, int b, int p){
    a %= p;
    if(a < 0) a += p;
    int s = 1 % p;  // 1 % p so that p == 1 yields 0
    while(b > 0){
        if(b & 1) s = z * s * a % p;
        a = z * a * a % p;
        b >>= 1;
    }
    return s;
}
// N: size of the value domain (inputs are below 2^18); mod: modulus for the counts.
const int N = 1 << 18;
const int mod = 1000000007;
// a: current f_t; b: f_1 (transformed in main); ans[i]: printed answer for i;
// re: indicator of the input set.
int a[N], b[N], ans[N], re[N];
// inv2: modular inverse of 2 (set in main); limit: FWT length.
int inv2, limit;
//f[i][j] = sum(f[i-1][k] * f[1][t])
void fwt(int *a, int type){
int mid, j, k, R, x, y;
for(mid = 1; mid < limit; mid <<= 1){
for(R = mid << 1, j = 0; j < limit; j += R){
for(k = 0; k < mid; k++){
x = a[j + k], y = a[j + k + mid];
a[j + k] = z * (x + y) * type % mod;
a[j + k + mid] = z * (x - y) * type % mod;
}
}
}
}
// Reads n values (each below 2^18) and, for i = 1..n, prints the largest XOR
// obtainable from exactly i of them (repetition allowed).
// f_1 is the indicator of the input set and f_i = f_{i-1} XOR-convolved with f_1.
// NOTE(review): counts are only known modulo `mod`, so a reachable value whose count
// is a multiple of `mod` would be treated as unreachable.
int main(){
    int i, j, n, maxn = 0;
    inv2 = ksm(2, mod - 2, mod);
    n = read();
    // a and b both start as f_1.
    for(i = 1; i <= n; i++){
        j = read();
        a[j] = b[j] = 1; maxn = max(maxn, j);
    }
    limit = (1 << 18);
    // b is f_1 in every round, so transform it once instead of re-copying and
    // re-transforming it each round.
    fwt(b, 1);
    for(i = 1; i <= n; i++){
        if(i > 19) ans[i] = ans[i - 2];  // answers repeat with period 2 from here on
        else if(i == 1) ans[i] = maxn;
        else{
            // a <- f_i from f_{i-1}.
            fwt(a, 1);
            for(j = 0; j < limit; j++) a[j] = z * a[j] * b[j] % mod;
            fwt(a, inv2);
            for(j = limit - 1; j >= 0; j--){
                if(a[j]){
                    ans[i] = j;
                    break;
                }
            }
        }
        printf("%d ", ans[i]);
    }
    return 0;
}
G Greater and Greater(bitset)
这是一道很巧妙的 bitset 题。

设 $s_{i,j}=1$($s_i$ 是个 bitset)表示 $a_i\ge b_j$。其实并不用 $n$ 个 bitset:把 $b$ 排序后,$s_i$ 只取决于 $b$ 中不超过 $a_i$ 的数的个数,因此本质不同的 $s_i$ 至多只有 $m+1$ 个,我们只用对每个 $a_i$ 记录编号即可,这里的空间是 $O(\frac{m^2}{\omega})$。

然后设 $cur_i$ 是一个 bitset,$cur_{i,j}=1$ 表示 $a_i$ 能和 $b_j$ 对齐匹配,即对所有 $0\le k\le m-j$ 都有 $a_{i+k}\ge b_{j+k}$(区间 $[i,i+m-j]$ 的每个数都不小于 $[j,m]$ 中对应位置的数)。于是有一个明显的转移:

$$cur_i=\big((cur_{i+1}\gg 1)\mid I_m\big)\,\&\,s_i$$

其中 $I_m$ 表示只有第 $m$ 位为 $1$ 的 bitset。然后从后往前滚动计算就可以了,$cur_{i,1}=1$ 说明从 $i$ 开始、长度为 $m$ 的区间合法。最后的复杂度是 $O(\frac{nm}{\omega})$ 的。
代码如下
#include <bits/stdc++.h>
#include<ext/pb_ds/hash_policy.hpp>
#include<ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
// 64-bit signed / unsigned shorthands used throughout this program.
using LL = long long;
using uLL = unsigned long long;
/// Hash functor for 64-bit keys. Each key is offset by a value taken once from
/// the steady clock and then mixed with SplitMix64 (presumably so that fixed
/// adversarial key sets cannot target the hash table's bucket layout).
struct custom_hash {
    /// SplitMix64 finalizer: a bijective 64-bit mixing function.
    static uint64_t splitmix64(uint64_t value) {
        value += 0x9e3779b97f4a7c15;
        value ^= value >> 30;
        value *= 0xbf58476d1ce4e5b9;
        value ^= value >> 27;
        value *= 0x94d049bb133111eb;
        return value ^ (value >> 31);
    }
    size_t operator()(uint64_t key) const {
        // Initialized on the first call only, then shared by every later call.
        static const uint64_t offset =
            std::chrono::steady_clock::now().time_since_epoch().count();
        return splitmix64(key + offset);
    }
};
// Multiplier used to promote int products to 64-bit before a modulus is taken.
// (Same type as the file's LL typedef, spelled out so this block is self-contained.)
long long z = 1;

/// Reads the next decimal integer from stdin.
/// Non-digit characters before the number are skipped; a '-' among them makes the
/// result negative.
/// @return the parsed value, or 0 if end of input is reached before any digit.
int read(){
    int x, f = 1;
    int ch;  // int, not char: a char cannot reliably hold EOF, which caused an endless loop
    while(ch = getchar(), ch < '0' || ch > '9'){
        if(ch == EOF) return 0;
        if(ch == '-') f = -1;
    }
    x = ch - '0';
    while(ch = getchar(), ch >= '0' && ch <= '9') x = x * 10 + ch - '0';
    return x * f;
}

/// Modular exponentiation by squaring.
/// @param a base; any int, reduced into [0, p) first
/// @param b exponent; expected >= 0 (a negative exponent is treated as 0)
/// @param p modulus; expected in [1, 2^31)
/// @return a^b mod p, in [0, p)
int ksm(int a, int b, int p){
    a %= p;
    if(a < 0) a += p;
    int s = 1 % p;  // 1 % p so that p == 1 yields 0
    while(b > 0){
        if(b & 1) s = z * s * a % p;
        a = z * a * a % p;
        b >>= 1;
    }
    return s;
}
// w[k]: bit j set iff pattern position j holds one of the k smallest b values
// (filled in main). cur: rolling match state while a is scanned right to left.
bitset<40002> w[40002];
bitset<40002> cur;
// Capacity of the 1-indexed arrays below.
const int N = 150005;
// a: the long sequence; b: the pattern (sorted in main);
// s[i]: number of b values <= a[i]; id: pattern positions ordered by value.
int a[N];
int b[N];
int s[N];
int id[N];
// Sort comparator on pattern positions: x goes before y when b[x] < b[y].
int cmp(int x, int y){
    const bool smaller = b[x] < b[y];
    return smaller ? 1 : 0;
}
int main(){
int i, j, n, m, l, r, mid, ans = 0;
n = read(); m = read();
for(i = 1; i <= n; i++) a[i] = read();
for(i = 1; i <= m; i++) b[i] = read(), id[i] = i;
sort(id + 1, id + m + 1, cmp);
sort(b + 1, b + m + 1);
for(i = 1; i <= m; i++){
w[i] = w[i - 1];
w[i][id[i]] = 1;
}
for(i = 1; i <= n; i++){
if(a[i] < b[1]){
s[i] = 0;
continue;
}
l = 1, r = m;
while(l < r){
mid = l + r + 1 >> 1;
if(b[mid] <= a[i]) l = mid;
else r = mid - 1;
}
s[i] = l;
}
for(i = n; i >= 1; i--){
cur >>= 1;
cur[m] = 1;
cur &= w[s[i]];
ans += (cur[1] == 1);
}
printf("%d", ans);
return 0;
}