添加链接描述
题意:
有A类字符和B类字符,有一个目标串。要求A类和B类字符串相互拼接,问拼成目标串至少需要多少个串。如果不能拼成,输出-1.
显然的要对
ABABAB
BABABA
的两种拼接方式取min
显然是可以转移的DP
dp[i][j] 为最后一个字符位置是 i ,用的是(1/0)A /B 类字符串。
转移方程
dp[i]j]=min(dp[i][j], dp[k][j^1]+1)
那么要如何较快的判断(i j )是否在A类或者B类。
对A每个字符串求哈希。时间复杂度加上A类字符串的长度(5e5)。对B类字符串求哈希,时间复杂度+B类字符串长度。(5e5)。
之后处理S的哈希数组。(5000)
然后dp 转移,是S长度的平方(5000*5000),在set 里面找,还有一个log 的复杂度。
哈希主要用在dp转移的时候,判断这个子串是否在A类或者B类
如果就这样写,会T。因为set 的插入和查询 常数很大。并且我们查询了很多不必要的子串(i j ),所以要减少find的次数。维护了一个bool 类型的长度数组。来判断A 或者B类中 有没有这个长度的串,之后再 find
附带一下我的勾丝代码。一开始写的单哈希wa了,后来改的双哈希。
我以前都是写hash 从下标1 开始存的hash ,写的时候脑子抽了,非要从0 开始存储。带来的结果就是 多了 l ==0 的特判,怪丑的。总之这个下标自己搞清楚就行。搞明白是从0开始还是从1 开始的。
#include <bits/stdc++.h>
using namespace std;
//#define ll long long
typedef long long ll;
#define pii pair<long long, long long>
const int N = 5000 + 10;
ll mod1 = 1e9 + 9;
ll mod2 = 1e9 + 7;
ll bas = 233;
ll p1[N];
ll p2[N];
ll ha1[N];
ll ha2[N];
const int M=5e5+5;
bool visa[M];
bool visb[M];
void init(int n)
{
p1[0] = 1;
p2[0] = 1;
for (int i = 1; i < n; i++)
{
p1[i] = p1[i - 1] * bas % mod1;
p2[i] = p2[i - 1] * bas % mod2;
}
}
int get1(int l, int r)
{
if (l == 0)
{
return ha1[r] % mod1;
}
return (ha1[r] - ha1[l - 1] * p1[r - l + 1] % mod1 + mod1) % mod1;
}
int get2(int l, int r)
{
if (l == 0)
{
return ha2[r] % mod2;
}
return (ha2[r] - ha2[l - 1] * p2[r - l + 1] % mod2 + mod2) % mod2;
}
pii get(int l, int r)
{
return {get1(l, r), get2(l, r)};
}
void solve()
{
int n;
cin >> n;
string s;
set<pii> sa;
while (n--)
{
cin >> s;
visa[s.size()]=true;
ll t1 = 0;
ll t2 = 0;
for (int i = 0; i < s.size(); i++)
{
t1 = t1 * bas % mod1 + s[i] - 'a' + 1;
t1 %= mod1;
t2 = t2 * bas % mod2 + s[i] - 'a' + 1;
t2 %= mod2;
}
sa.insert({t1, t2});
}
int m;
cin >> m;
set<pii> sb;
while (m--)
{
cin >> s;
visb[s.size()]=true;
ll t1 = 0;
ll t2 = 0;
for (int i = 0; i < s.size(); i++)
{
t1 = t1 * bas % mod1 + s[i] - 'a' + 1;
t1 %= mod1;
t2 = t2 * bas % mod2 + s[i] - 'a' + 1;
t2 %= mod2;
}
sb.insert({t1, t2});
}
cin >> s;
ll len = s.size();
init(len+1);
for (int i = 0; i < len; i++)
{
if (!i)
{
ha1[i] = (s[i] - 'a' + 1) % mod1;
ha2[i] = (s[i] - 'a' + 1) % mod2;
}
else
{
ha1[i] = (ha1[i - 1] * bas % mod1 + s[i] - 'a' + 1) % mod1;
ha2[i] = (ha2[i - 1] * bas % mod2 + s[i] - 'a' + 1) % mod2;
}
}
vector<vector<ll>> dp(len + 1, vector<ll>(2, 1e18));
// 0 代表 a
// 1代表b
// 以0 开始的处理
for (int i = 0; i < len; i++)
{
pii tar = get(0, i);
if (sa.find(tar) != sa.end())
{
dp[i][0] = 1;
}
if (sb.find(tar) != sb.end())
{
dp[i][1] = 1;
}
}
for (int i = 1; i < len; i++) // 结尾的
{
for (int j = 1; j <= i; j++)
{
pii tar = get(j, i);
// cout<<tar<<"\n";
if ( visa[i-j+1]&&sa.find(tar) != sa.end())
{
// cout<<"in1\n";
dp[i][0] = min(dp[i][0], dp[j - 1][1] + 1);
}
if (visb[i-j+1]&&sb.find(tar) != sb.end())
{
dp[i][1] = min(dp[i][1], dp[j - 1][0] + 1);
}
}
}
ll ans = min(dp[len - 1][0], dp[len - 1][1]);
cout << (ans == 1e18 ? -1 : ans) << "\n";
}
signed main()
{
std::cin.tie(nullptr)->sync_with_stdio(false);
int t = 1;
while (t--)
{
solve();
}
return 0;
}