题意:
给你一个01串s,问存在多少种数对<a,b><a,b><a,b>,使得a⊕b=sa \oplus b=sa⊕b=s。需满足a和b都是回文串,并且a和b的len<nlen<nlen<n。数据保证s的开头是1,若s[i]=?s[i]=?s[i]=?则该位置既可以是1,又可以使0。
题解:
思路来自官方题解。
本题可以转化为图论问题。首先考虑a和b的len都小于n,串s的开头一定是1,不妨假设b的开头是1,a的开头是0,那么a就是含有前导0的数,而b的长度一定为len。我们就可以枚举a的长度,最开头的数字是1。
1.保证回文:对于a串和b串,第i个与第n-i+1个点建边,边权为0,表示两数相等。
2.保证亦或后的结果:如果第i为是’1’则两点间建边,边权为1,表示不相等,如果为’0’,边权为0,表示相等。
接下来就是找图中连通块的个数。每个连通块的方案是两种,即0或1两种取反后有两种。
但是有问题:比如a和b的开头都是确定的,都是1,而且a的前导0也是给定的,这样很多连通块都是唯一的。
我们可以新加两个点,一个表示为0,一个表示为1。把所有已知的位置连向这两个点,这样已知点的都在同一个连通块里了。所以对于枚举的a的每一种情况,答案既是2C2^C2C,C即联通块的个数。
代码:
#include <bits/stdc++.h>
#define __ ios::sync_with_stdio(0);cin.tie(0);cout.tie(0)
#define ll long long
#define mem(a, b) memset(a, b, sizeof a)
using namespace std;
const int maxn = 5e3 + 10;
const ll mod = 998244353;
ll pow_mod(ll a, ll b, ll m) {
ll ans = 1;
while (b) {
if (b & 1)ans = ans * a % m;
a = a * a % m;
b >>= 1;
}
return ans;
}
string s;
struct pxy {
int to, next, vi;
} e[maxn * 10];
int n, head[maxn], cnt, col[maxn];
void ins(int x, int y, int z) {
e[++cnt].to = y;
e[cnt].next = head[x];
head[x] = cnt;
e[cnt].vi = z;
e[++cnt].to = x;
e[cnt].next = head[y];
head[y] = cnt;
e[cnt].vi = z;
}
bool check(int x) {
if (col[x] == -1)col[x] = 0;
for (int i = head[x]; i; i = e[i].next) {
if (col[e[i].to] == -1) {
if (e[i].vi == 1) {
col[e[i].to] = 1 - col[x];
} else {
col[e[i].to] = col[x];
}
if (!check(e[i].to))return 0;
} else {
if (e[i].vi == 1) {
if (col[x] != 1 - col[e[i].to])return 0;
} else {
if (col[x] != col[e[i].to])return 0;
}
}
}
return 1;
}
ll solve(int n, int m) {
cnt = 0;
mem(head, 0);
mem(col, -1);
mem(e, 0);
for (int i = 1; i + i <= n; ++i) {
ins(i, n - i + 1, 0);
}
for (int i = 1; i + i <= m; ++i) {
ins(n + i, n + m - i + 1, 0);
}
for (int i = n + m + 1; i <= n + n; ++i) {
ins(i, n + n + 2, 0);
}
ins(n + n + 1, n + n + 2, 1);
for (int i = 1; i <= n; ++i) {
if (s[i - 1] == '1')ins(i, n + i, 1);
if (s[i - 1] == '0')ins(i, n + i, 0);
}
ins(n, n + n + 1, 0);
ins(n + m, n + n + 1, 0);
int c = 0;
for (int i = 1; i <= n + n + 2; ++i) {
if (col[i] == -1) {
if (check(i)) {
c++;
} else return 0;
}
}
return pow_mod(2, c - 1, mod);
}
int main() {
__;
cin >> s;
reverse(s.begin(), s.end());
n = s.size();
ll ans = 0;
for (int i = 1; i < n; ++i) {
ans += solve(n, i);
ans %= mod;
}
cout << ans << endl;
return 0;
}