A All with Pairs(哈希+kmp)
把每个串的所有后缀的哈希值用一个桶记录下来,然后枚举每个串的前缀 $s$,就能快速知道这个前缀作为后缀出现了多少次。但是这个前缀不一定是最长的,这个很好解决:一个前缀 $[0,i]$ 满足条件的时候,$[0,next_i-1]$ 一定也满足条件,只需在 $next_i-1$ 处(即长度为 $next_i$ 的前缀)减掉这个前缀的贡献即可。
代码如下
#include <bits/stdc++.h>
#include<ext/pb_ds/hash_policy.hpp>
#include<ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
using LL = long long;
using uLL = unsigned long long;

/// Hash functor for the pb_ds tables: the SplitMix64 mixing step applied to
/// the key offset by a seed read from the steady clock on first use, so the
/// hash values differ from run to run.
struct custom_hash {
    /// One SplitMix64 step: add the golden-ratio increment, then xor-shift/multiply mix.
    static uint64_t splitmix64(uint64_t state) {
        uint64_t v = state + 0x9e3779b97f4a7c15ULL;
        v ^= v >> 30;
        v *= 0xbf58476d1ce4e5b9ULL;
        v ^= v >> 27;
        v *= 0x94d049bb133111ebULL;
        v ^= v >> 31;
        return v;
    }
    /// Hashes `key` after adding the per-process clock seed.
    size_t operator()(uint64_t key) const {
        static const uint64_t seed =
            std::chrono::steady_clock::now().time_since_epoch().count();
        return splitmix64(key + seed);
    }
};
// Constant 1 of type long long; multiplying by it first (z * a * b % p)
// promotes int arithmetic to 64 bits.
long long z = 1;

/// Reads the next decimal integer from stdin.
/// Characters before the first digit are skipped; any '-' among them makes the
/// result negative. Returns 0 if EOF is reached before a digit (the previous
/// version stored getchar() in a char and spun forever at EOF).
/// The character following the last digit is consumed.
int read(){
    int x = 0, f = 1;
    int ch;  // int, not char, so EOF can be told apart from a real byte
    while((ch = getchar()) < '0' || ch > '9'){
        if(ch == EOF) return 0;
        if(ch == '-') f = -1;
    }
    while(ch >= '0' && ch <= '9'){
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
/// Modular exponentiation: returns a^b mod p by repeated squaring.
/// b is expected to be non-negative.
int ksm(int a, int b, int p){
    long long result = 1;
    long long base = a;
    for(; b != 0; b >>= 1){
        if(b & 1) result = result * base % p;
        base = base * base % p;
    }
    return (int)result;
}
// Array bound for the string table and the per-character tables.
const int N = 1e6 + 5;
// Highest power of `base` precomputed into po[].
const int maxn = 1e6;
// Modulus applied to the final answer.
const int mod = 998244353;
string s[N];                             // input strings, 1-indexed
uLL hx;                                  // rolling prefix hash used in main
uLL hx1[N];                              // prefix hashes of the current string, hx1[0] = 0
uLL po[N];                               // po[k] = base^k, wrapping mod 2^64
uLL base = 131;                          // polynomial hash base
gp_hash_table<uLL, int, custom_hash> f;  // suffix hash -> number of occurrences
int nxt[N];                              // KMP border lengths of the current string
LL cnt[N];                               // per-prefix match counts of the current string

/// Polynomial hash of the 1-indexed substring [l, r], given the prefix-hash
/// array h (h[k] covers the first k characters) and the powers in po[].
uLL gethash(uLL h[], int l, int r){
    const int len = r - l + 1;
    const uLL whole = h[r];
    const uLL head = h[l - 1] * po[len];
    return whole - head;
}
/// Problem A driver. Hashes every suffix of every string. Then, for each
/// string and each prefix length L, it counts the strings having a suffix
/// equal to that prefix, keeping only the longest such L. It sums
/// count * L^2 modulo `mod` and prints the total.
/// NOTE(review): natural-overflow 64-bit hashing can collide on adversarial
/// input; equality is only up to hash collisions.
int main(){
    int i, j, n, m, t;
    for(po[0] = i = 1; i <= maxn; i++) po[i] = po[i - 1] * base;
    n = read();
    // Record the hash of every suffix of every string.
    for(i = 1; i <= n; i++){
        cin >> s[i]; m = s[i].length();
        for(j = 1; j <= m; j++) hx1[j] = hx1[j - 1] * base + s[i][j - 1];
        for(j = m; j >= 1; j--) f[gethash(hx1, j, m)]++;
    }
    LL ans = 0;
    for(i = 1; i <= n; i++){
        m = s[i].length();
        for(j = 0; j < m; j++) nxt[j] = cnt[j] = 0;
        // KMP failure function: nxt[j] = length of the longest proper border of s[i][0..j].
        t = 0;
        for(j = 1; j < m; j++){
            while(t && s[i][j] != s[i][t]) t = nxt[t - 1];
            if(s[i][j] == s[i][t]) nxt[j] = (++t);
        }
        hx = 0;
        for(j = 0; j < m; j++){
            hx = hx * base + s[i][j];
            // find() rather than operator[]: a lookup must not insert empty entries.
            auto it = f.find(hx);
            cnt[j] = (it == f.end()) ? 0 : it->second;
            // A suffix equal to prefix [0, j] also equals its border [0, nxt[j]-1];
            // remove it there so each match is only counted at its longest length.
            if(nxt[j] > 0) cnt[nxt[j] - 1] -= cnt[j];
        }
        for(j = 0; j < m; j++) ans = (ans + z * cnt[j] * (j + 1) * (j + 1) % mod) % mod;
    }
    // ans is long long, so it needs %lld (printing it with %d is undefined behavior).
    printf("%lld", ans);
    return 0;
}
E Exclusive OR(FWT)
由于 $A_i<2^{18}$,因此线性基的秩不超过 $18$。
又因为同一个数选两次异或和为 $0$,用 $i-2$ 个数能得到的异或值,用 $i$ 个数也一定能得到;因此当 $i\ge 20$ 时,$ans_i=ans_{i-2}$。
我们只需要求前 $19$ 项的值。
记 $f_{t,i}$ 表示用 $t$ 个数时,异或值为 $i$ 的方案数。
那么有一个简单的 $dp$:
$f_{t,i}=\sum\limits_{j\oplus k=i}f_{t-1,j}\times f_{1,k}$。
然后直接用 $FWT$ 做异或卷积即可。
代码如下
#include <bits/stdc++.h>
#include<ext/pb_ds/hash_policy.hpp>
#include<ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
using LL = long long;
using uLL = unsigned long long;

/// Hash functor for the pb_ds tables: the SplitMix64 mixing step applied to
/// the key offset by a seed read from the steady clock on first use, so the
/// hash values differ from run to run.
struct custom_hash {
    /// One SplitMix64 step: add the golden-ratio increment, then xor-shift/multiply mix.
    static uint64_t splitmix64(uint64_t state) {
        uint64_t v = state + 0x9e3779b97f4a7c15ULL;
        v ^= v >> 30;
        v *= 0xbf58476d1ce4e5b9ULL;
        v ^= v >> 27;
        v *= 0x94d049bb133111ebULL;
        v ^= v >> 31;
        return v;
    }
    /// Hashes `key` after adding the per-process clock seed.
    size_t operator()(uint64_t key) const {
        static const uint64_t seed =
            std::chrono::steady_clock::now().time_since_epoch().count();
        return splitmix64(key + seed);
    }
};
// Constant 1 of type long long; multiplying by it first (z * a * b % p)
// promotes int arithmetic to 64 bits.
long long z = 1;

/// Reads the next decimal integer from stdin.
/// Characters before the first digit are skipped; any '-' among them makes the
/// result negative. Returns 0 if EOF is reached before a digit (the previous
/// version stored getchar() in a char and spun forever at EOF).
/// The character following the last digit is consumed.
int read(){
    int x = 0, f = 1;
    int ch;  // int, not char, so EOF can be told apart from a real byte
    while((ch = getchar()) < '0' || ch > '9'){
        if(ch == EOF) return 0;
        if(ch == '-') f = -1;
    }
    while(ch >= '0' && ch <= '9'){
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
/// Modular exponentiation: returns a^b mod p by repeated squaring.
/// b is expected to be non-negative.
int ksm(int a, int b, int p){
    long long result = 1;
    long long base = a;
    for(; b != 0; b >>= 1){
        if(b & 1) result = result * base % p;
        base = base * base % p;
    }
    return (int)result;
}
const int N = (1 << 18), mod = 1e9 + 7;
// a, b: working arrays for the xor convolution; ans[i]: answer for i numbers;
// re: indicator of the input values (filled in main);
// inv2: 2^-1 mod `mod` (set in main); limit: transform length (a power of two).
int a[N], b[N], ans[N], re[N], inv2, limit;
//f[i][j] = sum(f[i-1][k] * f[1][t])

/// In-place Walsh-Hadamard (xor) transform of a[0, limit) modulo `mod`.
/// Every butterfly output is multiplied by `type`: pass 1 for the forward
/// transform and inv2 for the inverse (inv2 applied log2(limit) times = 1/limit).
/// Outputs are canonical residues in [0, mod) even when inputs or intermediate
/// differences are negative; plain `%` would leave negative residues behind.
void fwt(int *a, int type){
    // Reduce any (possibly negative) value into [0, mod).
    auto norm = [](long long v){
        v %= mod;
        return (int)(v < 0 ? v + mod : v);
    };
    for(int mid = 1; mid < limit; mid <<= 1){
        for(int j = 0; j < limit; j += mid << 1){
            for(int k = 0; k < mid; k++){
                long long x = a[j + k], y = a[j + k + mid];
                a[j + k] = norm((long long)norm(x + y) * type);
                a[j + k + mid] = norm((long long)norm(x - y) * type);
            }
        }
    }
}
/// Problem E driver: reads n values and, for every i in [1, n], prints the
/// largest xor obtainable from i of them (the same value may be reused),
/// separated by spaces.
int main(){
    int i, j, n, maxn = 0;
    inv2 = ksm(2, mod - 2, mod);
    n = read();
    for(i = 1; i <= n; i++){
        j = read();
        a[j] = b[j] = re[j] = 1; maxn = max(maxn, j);
    }
    limit = (1 << 18);
    // b (indicator of the input values) never changes, so transform it once
    // instead of re-copying and re-transforming it every round.
    fwt(b, 1);
    for(i = 1; i <= n; i++){
        if(i > 19) ans[i] = ans[i - 2];  // answers repeat with period 2 from here on
        else if(i == 1) ans[i] = maxn;
        else{
            // a <- a xor-convolved with the input indicator: values reachable with i numbers.
            fwt(a, 1);
            for(j = 0; j < limit; j++) a[j] = z * a[j] * b[j] % mod;
            fwt(a, inv2);
            // Keep only reachability (0/1). Raw counts would grow every round, and
            // one that is a multiple of mod would wrongly read as "unreachable".
            // With 0/1 inputs every count is at most 2^18 < mod.
            for(j = 0; j < limit; j++) a[j] = (a[j] != 0);
            for(j = limit - 1; j >= 0; j--){
                if(a[j]){
                    ans[i] = j;
                    break;
                }
            }
        }
        printf("%d ", ans[i]);
    }
    return 0;
}
G Greater and Greater(bitset)
这是一道很巧妙的 $bitset$ 题。
设 $s_{i,j}=1$($s_i$ 是个 $bitset$)表示 $a_i\ge b_j$。
其实并不用 $n$ 个 $bitset$:把 $b$ 排序后,$s_i$ 只取决于有多少个 $b_j\le a_i$,所以本质不同的 $s_i$ 只有 $m+1$ 个,我们只用记录编号即可,这里的空间是 $O(\frac{m^2}{\omega})$。
然后设 $cur_i$ 是一个 $bitset$,$cur_{i,j}=1$ 表示 $i$ 能和 $j$ 匹配,即对每个 $0\le k\le m-j$ 都有 $a_{i+k}\ge b_{j+k}$(区间 $[i,i+m-j]$ 与 $[j,m]$ 的对应位置逐个比较)。记 $I_m$ 为只有第 $m$ 位为 $1$ 的 $bitset$,于是有一个明显的转移:$cur_i=\big((cur_{i+1}\gg 1)\mid I_m\big)\,\&\,s_i$。
然后从后往前滚动 $cur$ 就可以了,答案就是满足 $cur_{i,1}=1$ 的 $i$ 的个数。最后的复杂度是 $O(\frac{nm}{\omega})$ 的。
代码如下
#include <bits/stdc++.h>
#include<ext/pb_ds/hash_policy.hpp>
#include<ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
using LL = long long;
using uLL = unsigned long long;

/// Hash functor for the pb_ds tables: the SplitMix64 mixing step applied to
/// the key offset by a seed read from the steady clock on first use, so the
/// hash values differ from run to run.
struct custom_hash {
    /// One SplitMix64 step: add the golden-ratio increment, then xor-shift/multiply mix.
    static uint64_t splitmix64(uint64_t state) {
        uint64_t v = state + 0x9e3779b97f4a7c15ULL;
        v ^= v >> 30;
        v *= 0xbf58476d1ce4e5b9ULL;
        v ^= v >> 27;
        v *= 0x94d049bb133111ebULL;
        v ^= v >> 31;
        return v;
    }
    /// Hashes `key` after adding the per-process clock seed.
    size_t operator()(uint64_t key) const {
        static const uint64_t seed =
            std::chrono::steady_clock::now().time_since_epoch().count();
        return splitmix64(key + seed);
    }
};
// Constant 1 of type long long; multiplying by it first (z * a * b % p)
// promotes int arithmetic to 64 bits.
long long z = 1;

/// Reads the next decimal integer from stdin.
/// Characters before the first digit are skipped; any '-' among them makes the
/// result negative. Returns 0 if EOF is reached before a digit (the previous
/// version stored getchar() in a char and spun forever at EOF).
/// The character following the last digit is consumed.
int read(){
    int x = 0, f = 1;
    int ch;  // int, not char, so EOF can be told apart from a real byte
    while((ch = getchar()) < '0' || ch > '9'){
        if(ch == EOF) return 0;
        if(ch == '-') f = -1;
    }
    while(ch >= '0' && ch <= '9'){
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
/// Modular exponentiation: returns a^b mod p by repeated squaring.
/// b is expected to be non-negative.
int ksm(int a, int b, int p){
    long long result = 1;
    long long base = a;
    for(; b != 0; b >>= 1){
        if(b & 1) result = result * base % p;
        base = base * base % p;
    }
    return (int)result;
}
// w[k]: bit p is set iff B's original position p holds one of the k smallest values
// (built in main). cur: rolling match state scanned from right to left.
bitset<40002> w[40002];
bitset<40002> cur;
const int N = 150005;
int a[N], b[N], s[N], id[N];

/// Comparator for std::sort over indices: orders positions of b by ascending value.
int cmp(int lhs, int rhs){
    const int left_value = b[lhs];
    const int right_value = b[rhs];
    return left_value < right_value;
}
int main(){
int i, j, n, m, l, r, mid, ans = 0;
n = read(); m = read();
for(i = 1; i <= n; i++) a[i] = read();
for(i = 1; i <= m; i++) b[i] = read(), id[i] = i;
sort(id + 1, id + m + 1, cmp);
sort(b + 1, b + m + 1);
for(i = 1; i <= m; i++){
w[i] = w[i - 1];
w[i][id[i]] = 1;
}
for(i = 1; i <= n; i++){
if(a[i] < b[1]){
s[i] = 0;
continue;
}
l = 1, r = m;
while(l < r){
mid = l + r + 1 >> 1;
if(b[mid] <= a[i]) l = mid;
else r = mid - 1;
}
s[i] = l;
}
for(i = n; i >= 1; i--){
cur >>= 1;
cur[m] = 1;
cur &= w[s[i]];
ans += (cur[1] == 1);
}
printf("%d", ans);
return 0;
}